From b94b7f24b393dd35181c536c0bc603ba5603d1e1 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 15 Aug 2025 09:36:25 -0700 Subject: [PATCH 01/54] wip --- .../ClusterAwareWriterFailoverHandler.java | 17 +- .../failover/FailoverConnectionPlugin.java | 39 +- .../FailoverConnectionPluginFactory.java | 16 +- .../jdbc/util/CoreServicesContainer.java | 2 +- ...ClusterAwareWriterFailoverHandlerTest.java | 816 ++++++++-------- .../FailoverConnectionPluginTest.java | 884 +++++++++--------- 6 files changed, 909 insertions(+), 865 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 47f741b5e..07d04d4f9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -35,9 +35,11 @@ import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionService; /** * An implementation of WriterFailoverHandler. @@ -56,28 +58,33 @@ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler protected int reconnectWriterIntervalMs = 5000; // 5 sec protected Properties initialConnectionProps; protected PluginService pluginService; + protected ConnectionService connectionService; protected ReaderFailoverHandler readerFailoverHandler; private static final WriterFailoverResult DEFAULT_RESULT = new WriterFailoverResult(false, false, null, null, "None"); public ClusterAwareWriterFailoverHandler( - final PluginService pluginService, + final FullServicesContainer servicesContainer, + final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps) { - this.pluginService = pluginService; + this.pluginService = servicesContainer.getPluginService(); + this.connectionService = connectionService; this.readerFailoverHandler = readerFailoverHandler; this.initialConnectionProps = initialConnectionProps; } public ClusterAwareWriterFailoverHandler( - final PluginService pluginService, + final FullServicesContainer servicesContainer, + final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps, final int failoverTimeoutMs, final int readTopologyIntervalMs, final int reconnectWriterIntervalMs) { this( - pluginService, + servicesContainer, + connectionService, readerFailoverHandler, initialConnectionProps); this.maxFailoverTimeoutMs = failoverTimeoutMs; @@ -260,7 +267,7 @@ public WriterFailoverResult call() { // TODO: assess whether multi-threaded access to the plugin service is safe. The same plugin service is used // by both the ConnectionWrapper and this ReconnectToWriterHandler in separate threads. - conn = pluginService.forceConnect(this.originalWriterHost, initialConnectionProps); + conn = connectionService.open(this.originalWriterHost, initialConnectionProps); pluginService.forceRefreshHostList(conn); latestTopology = pluginService.getAllHosts(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 8c16484ea..1807a2f8c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -33,6 +33,8 @@ import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.DriverConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; @@ -42,16 +44,20 @@ import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.TargetDriverHelper; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsHelper; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionService; +import software.amazon.jdbc.util.connection.ConnectionServiceImpl; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -92,6 +98,8 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { private final Set subscribedMethods; private final PluginService pluginService; + private final FullServicesContainer servicesContainer; + private final ConnectionService connectionService; protected final Properties properties; protected boolean enableFailoverSetting; protected boolean enableConnectFailover; @@ -185,15 +193,16 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { PropertyDefinition.registerPluginProperties(FailoverConnectionPlugin.class); } - public FailoverConnectionPlugin(final PluginService pluginService, final Properties properties) { - this(pluginService, properties, new RdsUtils()); + public FailoverConnectionPlugin(final FullServicesContainer servicesContainer, final Properties properties) { + this(servicesContainer, properties, new RdsUtils()); } FailoverConnectionPlugin( - final PluginService pluginService, + final FullServicesContainer servicesContainer, final Properties properties, final RdsUtils rdsHelper) { - this.pluginService = pluginService; + this.servicesContainer = servicesContainer; + this.pluginService = servicesContainer.getPluginService(); this.properties = properties; this.rdsHelper = rdsHelper; @@ -222,6 +231,25 @@ public FailoverConnectionPlugin(final PluginService pluginService, final Propert } this.subscribedMethods = Collections.unmodifiableSet(methods); + try { + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + this.connectionService = new ConnectionServiceImpl( + servicesContainer.getStorageService(), + servicesContainer.getMonitorService(), + servicesContainer.getTelemetryFactory(), + defaultConnectionProvider, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect(), + properties + ); + } catch (SQLException e) { + throw new RuntimeException(e); + } + TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); this.failoverWriterTriggeredCounter = telemetryFactory.createCounter("writerFailover.triggered.count"); this.failoverWriterSuccessCounter = telemetryFactory.createCounter("writerFailover.completed.success.count"); @@ -316,7 +344,8 @@ public void initHostProvider( this.failoverMode == FailoverMode.STRICT_READER), () -> new ClusterAwareWriterFailoverHandler( - this.pluginService, + this.servicesContainer, + this.connectionService, this.readerFailoverHandler, this.properties, this.failoverTimeoutMsSetting, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginFactory.java index 75d445710..9fd77a767 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginFactory.java @@ -18,13 +18,21 @@ import java.util.Properties; import software.amazon.jdbc.ConnectionPlugin; -import software.amazon.jdbc.ConnectionPluginFactory; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.ServicesContainerPluginFactory; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.Messages; -public class FailoverConnectionPluginFactory implements ConnectionPluginFactory { - +public class FailoverConnectionPluginFactory implements ServicesContainerPluginFactory { @Override public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { - return new FailoverConnectionPlugin(pluginService, props); + throw new UnsupportedOperationException( + Messages.get( + "ServicesContainerPluginFactory.servicesContainerRequired", new Object[] {"FailoverConnectionPlugin"})); + } + + @Override + public ConnectionPlugin getInstance(final FullServicesContainer servicesContainer, final Properties props) { + return new FailoverConnectionPlugin(servicesContainer, props); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/CoreServicesContainer.java b/wrapper/src/main/java/software/amazon/jdbc/util/CoreServicesContainer.java index 3ac3e0a1b..1c01d5bbd 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/CoreServicesContainer.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/CoreServicesContainer.java @@ -33,8 +33,8 @@ public class CoreServicesContainer { private static final CoreServicesContainer INSTANCE = new CoreServicesContainer(); - private final StorageService storageService; private final MonitorService monitorService; + private final StorageService storageService; private CoreServicesContainer() { EventPublisher eventPublisher = new BatchingEventPublisher(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 0944be13c..5f302209a 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -1,408 +1,408 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.refEq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentMatchers; -import org.mockito.InOrder; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; - -class ClusterAwareWriterFailoverHandlerTest { - - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - @Mock ReaderFailoverHandler mockReaderFailover; - @Mock Connection mockWriterConnection; - @Mock Connection mockNewWriterConnection; - @Mock Connection mockReaderAConnection; - @Mock Connection mockReaderBConnection; - @Mock Dialect mockDialect; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-writer-host").build(); - private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer-host").build(); - private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-a-host").build(); - private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-b-host").build(); - private final List topology = Arrays.asList(writer, readerA, readerB); - private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - writer.addAlias("writer-host"); - newWriterHost.addAlias("new-writer-host"); - readerA.addAlias("reader-a-host"); - readerB.addAlias("reader-b-host"); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testReconnectToWriter_taskBReaderException() throws SQLException { - when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockConnection); - when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenThrow(SQLException.class); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 5000, - 2000, - 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockConnection); - - final InOrder inOrder = Mockito.inOrder(mockPluginService); - inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: successfully re-connect to initial writer; return new connection. - * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_SlowReaderA() throws SQLException { - when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return new ReaderFailoverResult(mockReaderAConnection, readerA, true); - }); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 60000, - 5000, - 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - - final InOrder inOrder = Mockito.inOrder(mockPluginService); - inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes. - * TaskA: successfully re-connect to writer; return new connection. - * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_taskBDefers() throws SQLException { - when(mockPluginService.forceConnect(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 60000, - 2000, - 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - - final InOrder inOrder = Mockito.inOrder(mockPluginService); - inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. - * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more - * time than taskB. - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_SlowWriter() throws SQLException { - when(mockPluginService.forceConnect(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 60000, - 5000, - 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(3, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - - verify(mockPluginService, times(1)).setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. - * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_taskADefers() throws SQLException { - when(mockPluginService.forceConnect(writer, properties)).thenReturn(mockConnection); - when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockNewWriterConnection; - }); - - final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 60000, - 5000, - 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(4, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - verify(mockPluginService, times(1)).setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to failover timeout. - * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_failoverTimeout() throws SQLException { - when(mockPluginService.forceConnect(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockWriterConnection; - }); - when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockNewWriterConnection; - }); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 5000, - 2000, - 2000); - - final long startTimeNano = System.nanoTime(); - final WriterFailoverResult result = target.failover(topology); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to exception. - * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenThrow(exception); - when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenThrow(exception); - when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = - new ClusterAwareWriterFailoverHandler( - mockPluginService, - mockReaderFailover, - properties, - 5000, - 2000, - 2000); - final WriterFailoverResult result = target.failover(topology); - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - verify(mockPluginService, atLeastOnce()) - .setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.NOT_AVAILABLE)); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.ArgumentMatchers.refEq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Arrays; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.ArgumentMatchers; +// import org.mockito.InOrder; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// +// class ClusterAwareWriterFailoverHandlerTest { +// +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// @Mock ReaderFailoverHandler mockReaderFailover; +// @Mock Connection mockWriterConnection; +// @Mock Connection mockNewWriterConnection; +// @Mock Connection mockReaderAConnection; +// @Mock Connection mockReaderBConnection; +// @Mock Dialect mockDialect; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("new-writer-host").build(); +// private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer-host").build(); +// private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-a-host").build(); +// private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-b-host").build(); +// private final List topology = Arrays.asList(writer, readerA, readerB); +// private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// writer.addAlias("writer-host"); +// newWriterHost.addAlias("new-writer-host"); +// readerA.addAlias("reader-a-host"); +// readerB.addAlias("reader-b-host"); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testReconnectToWriter_taskBReaderException() throws SQLException { +// when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockConnection); +// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenThrow(SQLException.class); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 5000, +// 2000, +// 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockConnection); +// +// final InOrder inOrder = Mockito.inOrder(mockPluginService); +// inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: successfully re-connect to initial writer; return new connection. +// * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_SlowReaderA() throws SQLException { +// when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return new ReaderFailoverResult(mockReaderAConnection, readerA, true); +// }); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 60000, +// 5000, +// 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// +// final InOrder inOrder = Mockito.inOrder(mockPluginService); +// inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes. +// * TaskA: successfully re-connect to writer; return new connection. +// * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_taskBDefers() throws SQLException { +// when(mockPluginService.forceConnect(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 60000, +// 2000, +// 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// +// final InOrder inOrder = Mockito.inOrder(mockPluginService); +// inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. +// * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more +// * time than taskB. +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_SlowWriter() throws SQLException { +// when(mockPluginService.forceConnect(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 60000, +// 5000, +// 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(3, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// +// verify(mockPluginService, times(1)).setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. +// * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_taskADefers() throws SQLException { +// when(mockPluginService.forceConnect(writer, properties)).thenReturn(mockConnection); +// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockNewWriterConnection; +// }); +// +// final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 60000, +// 5000, +// 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(4, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// verify(mockPluginService, times(1)).setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to failover timeout. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_failoverTimeout() throws SQLException { +// when(mockPluginService.forceConnect(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockWriterConnection; +// }); +// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockNewWriterConnection; +// }); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 5000, +// 2000, +// 2000); +// +// final long startTimeNano = System.nanoTime(); +// final WriterFailoverResult result = target.failover(topology); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to exception. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenThrow(exception); +// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenThrow(exception); +// when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = +// new ClusterAwareWriterFailoverHandler( +// mockPluginService, +// mockReaderFailover, +// properties, +// 5000, +// 2000, +// 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// verify(mockPluginService, atLeastOnce()) +// .setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.NOT_AVAILABLE)); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index 97ab5f119..33160333f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -1,442 +1,442 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.NodeChangeOptions; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.RdsUrlType; -import software.amazon.jdbc.util.SqlState; -import software.amazon.jdbc.util.telemetry.GaugeCallable; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; - -class FailoverConnectionPluginTest { - - private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; - private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; - private static final Object[] EMPTY_ARGS = {}; - private final List defaultHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build()); - - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - @Mock HostSpec mockHostSpec; - @Mock HostListProviderService mockHostListProviderService; - @Mock AuroraHostListProvider mockHostListProvider; - @Mock JdbcCallable mockInitHostProviderFunc; - @Mock ReaderFailoverHandler mockReaderFailoverHandler; - @Mock WriterFailoverHandler mockWriterFailoverHandler; - @Mock ReaderFailoverResult mockReaderResult; - @Mock WriterFailoverResult mockWriterResult; - @Mock JdbcCallable mockSqlFunction; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryGauge mockTelemetryGauge; - @Mock TargetDriverDialect mockTargetDriverDialect; - - - private final Properties properties = new Properties(); - private FailoverConnectionPlugin plugin; - private AutoCloseable closeable; - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @BeforeEach - void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - - when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); - when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); - when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockPluginService.getHosts()).thenReturn(defaultHosts); - when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); - when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); - when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); - when(mockWriterResult.isConnected()).thenReturn(true); - when(mockWriterResult.getTopology()).thenReturn(defaultHosts); - when(mockReaderResult.isConnected()).thenReturn(true); - - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - // noinspection unchecked - when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - - when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); - when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); - - properties.clear(); - } - - @Test - void test_notifyNodeListChanged_withFailoverDisabled() { - properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); - final Map> changes = new HashMap<>(); - - initializePlugin(); - plugin.notifyNodeListChanged(changes); - - verify(mockPluginService, never()).getCurrentHostSpec(); - verify(mockHostSpec, never()).getAliases(); - } - - @Test - void test_notifyNodeListChanged_withValidConnectionNotInTopology() { - final Map> changes = new HashMap<>(); - changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); - changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); - - initializePlugin(); - plugin.notifyNodeListChanged(changes); - - when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); - - verify(mockPluginService).getCurrentHostSpec(); - verify(mockHostSpec, never()).getAliases(); - } - - @Test - void test_updateTopology() throws SQLException { - initializePlugin(); - - // Test updateTopology with failover disabled - plugin.setRdsUrlType(RdsUrlType.RDS_PROXY); - plugin.updateTopology(false); - verify(mockPluginService, never()).forceRefreshHostList(); - verify(mockPluginService, never()).refreshHostList(); - - // Test updateTopology with no connection - when(mockPluginService.getCurrentHostSpec()).thenReturn(null); - plugin.updateTopology(false); - verify(mockPluginService, never()).forceRefreshHostList(); - verify(mockPluginService, never()).refreshHostList(); - - // Test updateTopology with closed connection - when(mockConnection.isClosed()).thenReturn(true); - plugin.updateTopology(false); - verify(mockPluginService, never()).forceRefreshHostList(); - verify(mockPluginService, never()).refreshHostList(); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { - - when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); - when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); - when(mockConnection.isClosed()).thenReturn(false); - initializePlugin(); - plugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); - - plugin.updateTopology(forceUpdate); - if (forceUpdate) { - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); - } else { - verify(mockPluginService, atLeastOnce()).refreshHostList(); - } - } - - @Test - void test_failover_failoverWriter() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(true); - - initializePlugin(); - final FailoverConnectionPlugin spyPlugin = spy(plugin); - doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); - spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; - - assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); - verify(spyPlugin).failoverWriter(); - } - - @Test - void test_failover_failoverReader() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(false); - - initializePlugin(); - final FailoverConnectionPlugin spyPlugin = spy(plugin); - doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); - spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; - - assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); - verify(spyPlugin).failoverReader(eq(mockHostSpec)); - } - - @Test - void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); - when(mockReaderResult.isConnected()).thenReturn(true); - when(mockReaderResult.getConnection()).thenReturn(mockConnection); - when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); - - initializePlugin(); - plugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - () -> mockReaderFailoverHandler, - () -> mockWriterFailoverHandler); - - final FailoverConnectionPlugin spyPlugin = spy(plugin); - doNothing().when(spyPlugin).updateTopology(true); - - assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); - - verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); - verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); - } - - @Test - void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .build(); - final List hosts = Collections.singletonList(hostSpec); - - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); - when(mockPluginService.getAllHosts()).thenReturn(hosts); - when(mockPluginService.getHosts()).thenReturn(hosts); - when(mockReaderResult.getException()).thenReturn(new SQLException()); - when(mockReaderResult.getHost()).thenReturn(hostSpec); - - initializePlugin(); - plugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - () -> mockReaderFailoverHandler, - () -> mockWriterFailoverHandler); - - assertThrows(SQLException.class, () -> plugin.failoverReader(null)); - verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); - } - - @Test - void test_failoverWriter_failedFailover_throwsException() throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .build(); - final List hosts = Collections.singletonList(hostSpec); - - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockPluginService.getAllHosts()).thenReturn(hosts); - when(mockPluginService.getHosts()).thenReturn(hosts); - when(mockWriterResult.getException()).thenReturn(new SQLException()); - - initializePlugin(); - plugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - () -> mockReaderFailoverHandler, - () -> mockWriterFailoverHandler); - - assertThrows(SQLException.class, () -> plugin.failoverWriter()); - verify(mockWriterFailoverHandler).failover(eq(hosts)); - } - - @Test - void test_failoverWriter_failedFailover_withNoResult() throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .build(); - final List hosts = Collections.singletonList(hostSpec); - - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockPluginService.getAllHosts()).thenReturn(hosts); - when(mockPluginService.getHosts()).thenReturn(hosts); - when(mockWriterResult.isConnected()).thenReturn(false); - - initializePlugin(); - plugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - () -> mockReaderFailoverHandler, - () -> mockWriterFailoverHandler); - - final SQLException exception = assertThrows(SQLException.class, () -> plugin.failoverWriter()); - assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); - - verify(mockWriterFailoverHandler).failover(eq(hosts)); - verify(mockWriterResult, never()).getNewConnection(); - verify(mockWriterResult, never()).getTopology(); - } - - @Test - void test_failoverWriter_successFailover() throws SQLException { - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - - initializePlugin(); - plugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - () -> mockReaderFailoverHandler, - () -> mockWriterFailoverHandler); - - final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> plugin.failoverWriter()); - assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); - - verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); - } - - @Test - void test_invalidCurrentConnection_withNoConnection() { - when(mockPluginService.getCurrentConnection()).thenReturn(null); - initializePlugin(); - plugin.invalidateCurrentConnection(); - - verify(mockPluginService, never()).getCurrentHostSpec(); - } - - @Test - void test_invalidateCurrentConnection_inTransaction() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(true); - when(mockHostSpec.getHost()).thenReturn("host"); - when(mockHostSpec.getPort()).thenReturn(123); - when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - - initializePlugin(); - plugin.invalidateCurrentConnection(); - verify(mockConnection).rollback(); - - // Assert SQL exceptions thrown during rollback do not get propagated. - doThrow(new SQLException()).when(mockConnection).rollback(); - assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); - } - - @Test - void test_invalidateCurrentConnection_notInTransaction() { - when(mockPluginService.isInTransaction()).thenReturn(false); - when(mockHostSpec.getHost()).thenReturn("host"); - when(mockHostSpec.getPort()).thenReturn(123); - when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - - initializePlugin(); - plugin.invalidateCurrentConnection(); - - verify(mockPluginService).isInTransaction(); - } - - @Test - void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(false); - when(mockConnection.isClosed()).thenReturn(false); - when(mockHostSpec.getHost()).thenReturn("host"); - when(mockHostSpec.getPort()).thenReturn(123); - when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - - initializePlugin(); - plugin.invalidateCurrentConnection(); - - doThrow(new SQLException()).when(mockConnection).close(); - assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); - - verify(mockConnection, times(2)).isClosed(); - verify(mockConnection, times(2)).close(); - } - - @Test - void test_execute_withFailoverDisabled() throws SQLException { - properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); - initializePlugin(); - - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - mockSqlFunction, - EMPTY_ARGS); - - verify(mockSqlFunction).call(); - verify(mockHostListProvider, never()).getRdsUrlType(); - } - - @Test - void test_execute_withDirectExecute() throws SQLException { - initializePlugin(); - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - "close", - mockSqlFunction, - EMPTY_ARGS); - verify(mockSqlFunction).call(); - verify(mockHostListProvider, never()).getRdsUrlType(); - } - - private void initializePlugin() { - plugin = new FailoverConnectionPlugin(mockPluginService, properties); - plugin.setWriterFailoverHandler(mockWriterFailoverHandler); - plugin.setReaderFailoverHandler(mockReaderFailoverHandler); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.doNothing; +// import static org.mockito.Mockito.doThrow; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.HashMap; +// import java.util.HashSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.ValueSource; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.NodeChangeOptions; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.RdsUrlType; +// import software.amazon.jdbc.util.SqlState; +// import software.amazon.jdbc.util.telemetry.GaugeCallable; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryGauge; +// +// class FailoverConnectionPluginTest { +// +// private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; +// private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; +// private static final Object[] EMPTY_ARGS = {}; +// private final List defaultHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build()); +// +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// @Mock HostSpec mockHostSpec; +// @Mock HostListProviderService mockHostListProviderService; +// @Mock AuroraHostListProvider mockHostListProvider; +// @Mock JdbcCallable mockInitHostProviderFunc; +// @Mock ReaderFailoverHandler mockReaderFailoverHandler; +// @Mock WriterFailoverHandler mockWriterFailoverHandler; +// @Mock ReaderFailoverResult mockReaderResult; +// @Mock WriterFailoverResult mockWriterResult; +// @Mock JdbcCallable mockSqlFunction; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryGauge mockTelemetryGauge; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// +// +// private final Properties properties = new Properties(); +// private FailoverConnectionPlugin plugin; +// private AutoCloseable closeable; +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @BeforeEach +// void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// +// when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); +// when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); +// when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockPluginService.getHosts()).thenReturn(defaultHosts); +// when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); +// when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); +// when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); +// when(mockWriterResult.isConnected()).thenReturn(true); +// when(mockWriterResult.getTopology()).thenReturn(defaultHosts); +// when(mockReaderResult.isConnected()).thenReturn(true); +// +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// // noinspection unchecked +// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); +// +// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); +// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); +// +// properties.clear(); +// } +// +// @Test +// void test_notifyNodeListChanged_withFailoverDisabled() { +// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); +// final Map> changes = new HashMap<>(); +// +// initializePlugin(); +// plugin.notifyNodeListChanged(changes); +// +// verify(mockPluginService, never()).getCurrentHostSpec(); +// verify(mockHostSpec, never()).getAliases(); +// } +// +// @Test +// void test_notifyNodeListChanged_withValidConnectionNotInTopology() { +// final Map> changes = new HashMap<>(); +// changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); +// changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); +// +// initializePlugin(); +// plugin.notifyNodeListChanged(changes); +// +// when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); +// +// verify(mockPluginService).getCurrentHostSpec(); +// verify(mockHostSpec, never()).getAliases(); +// } +// +// @Test +// void test_updateTopology() throws SQLException { +// initializePlugin(); +// +// // Test updateTopology with failover disabled +// plugin.setRdsUrlType(RdsUrlType.RDS_PROXY); +// plugin.updateTopology(false); +// verify(mockPluginService, never()).forceRefreshHostList(); +// verify(mockPluginService, never()).refreshHostList(); +// +// // Test updateTopology with no connection +// when(mockPluginService.getCurrentHostSpec()).thenReturn(null); +// plugin.updateTopology(false); +// verify(mockPluginService, never()).forceRefreshHostList(); +// verify(mockPluginService, never()).refreshHostList(); +// +// // Test updateTopology with closed connection +// when(mockConnection.isClosed()).thenReturn(true); +// plugin.updateTopology(false); +// verify(mockPluginService, never()).forceRefreshHostList(); +// verify(mockPluginService, never()).refreshHostList(); +// } +// +// @ParameterizedTest +// @ValueSource(booleans = {true, false}) +// void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { +// +// when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); +// when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); +// when(mockConnection.isClosed()).thenReturn(false); +// initializePlugin(); +// plugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); +// +// plugin.updateTopology(forceUpdate); +// if (forceUpdate) { +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); +// } else { +// verify(mockPluginService, atLeastOnce()).refreshHostList(); +// } +// } +// +// @Test +// void test_failover_failoverWriter() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(true); +// +// initializePlugin(); +// final FailoverConnectionPlugin spyPlugin = spy(plugin); +// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); +// spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; +// +// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); +// verify(spyPlugin).failoverWriter(); +// } +// +// @Test +// void test_failover_failoverReader() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(false); +// +// initializePlugin(); +// final FailoverConnectionPlugin spyPlugin = spy(plugin); +// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); +// spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; +// +// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); +// verify(spyPlugin).failoverReader(eq(mockHostSpec)); +// } +// +// @Test +// void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); +// when(mockReaderResult.isConnected()).thenReturn(true); +// when(mockReaderResult.getConnection()).thenReturn(mockConnection); +// when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); +// +// initializePlugin(); +// plugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// () -> mockReaderFailoverHandler, +// () -> mockWriterFailoverHandler); +// +// final FailoverConnectionPlugin spyPlugin = spy(plugin); +// doNothing().when(spyPlugin).updateTopology(true); +// +// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); +// +// verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); +// verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); +// } +// +// @Test +// void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .build(); +// final List hosts = Collections.singletonList(hostSpec); +// +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); +// when(mockPluginService.getAllHosts()).thenReturn(hosts); +// when(mockPluginService.getHosts()).thenReturn(hosts); +// when(mockReaderResult.getException()).thenReturn(new SQLException()); +// when(mockReaderResult.getHost()).thenReturn(hostSpec); +// +// initializePlugin(); +// plugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// () -> mockReaderFailoverHandler, +// () -> mockWriterFailoverHandler); +// +// assertThrows(SQLException.class, () -> plugin.failoverReader(null)); +// verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); +// } +// +// @Test +// void test_failoverWriter_failedFailover_throwsException() throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .build(); +// final List hosts = Collections.singletonList(hostSpec); +// +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockPluginService.getAllHosts()).thenReturn(hosts); +// when(mockPluginService.getHosts()).thenReturn(hosts); +// when(mockWriterResult.getException()).thenReturn(new SQLException()); +// +// initializePlugin(); +// plugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// () -> mockReaderFailoverHandler, +// () -> mockWriterFailoverHandler); +// +// assertThrows(SQLException.class, () -> plugin.failoverWriter()); +// verify(mockWriterFailoverHandler).failover(eq(hosts)); +// } +// +// @Test +// void test_failoverWriter_failedFailover_withNoResult() throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .build(); +// final List hosts = Collections.singletonList(hostSpec); +// +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockPluginService.getAllHosts()).thenReturn(hosts); +// when(mockPluginService.getHosts()).thenReturn(hosts); +// when(mockWriterResult.isConnected()).thenReturn(false); +// +// initializePlugin(); +// plugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// () -> mockReaderFailoverHandler, +// () -> mockWriterFailoverHandler); +// +// final SQLException exception = assertThrows(SQLException.class, () -> plugin.failoverWriter()); +// assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); +// +// verify(mockWriterFailoverHandler).failover(eq(hosts)); +// verify(mockWriterResult, never()).getNewConnection(); +// verify(mockWriterResult, never()).getTopology(); +// } +// +// @Test +// void test_failoverWriter_successFailover() throws SQLException { +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// +// initializePlugin(); +// plugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// () -> mockReaderFailoverHandler, +// () -> mockWriterFailoverHandler); +// +// final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> plugin.failoverWriter()); +// assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); +// +// verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); +// } +// +// @Test +// void test_invalidCurrentConnection_withNoConnection() { +// when(mockPluginService.getCurrentConnection()).thenReturn(null); +// initializePlugin(); +// plugin.invalidateCurrentConnection(); +// +// verify(mockPluginService, never()).getCurrentHostSpec(); +// } +// +// @Test +// void test_invalidateCurrentConnection_inTransaction() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(true); +// when(mockHostSpec.getHost()).thenReturn("host"); +// when(mockHostSpec.getPort()).thenReturn(123); +// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); +// +// initializePlugin(); +// plugin.invalidateCurrentConnection(); +// verify(mockConnection).rollback(); +// +// // Assert SQL exceptions thrown during rollback do not get propagated. +// doThrow(new SQLException()).when(mockConnection).rollback(); +// assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); +// } +// +// @Test +// void test_invalidateCurrentConnection_notInTransaction() { +// when(mockPluginService.isInTransaction()).thenReturn(false); +// when(mockHostSpec.getHost()).thenReturn("host"); +// when(mockHostSpec.getPort()).thenReturn(123); +// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); +// +// initializePlugin(); +// plugin.invalidateCurrentConnection(); +// +// verify(mockPluginService).isInTransaction(); +// } +// +// @Test +// void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(false); +// when(mockConnection.isClosed()).thenReturn(false); +// when(mockHostSpec.getHost()).thenReturn("host"); +// when(mockHostSpec.getPort()).thenReturn(123); +// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); +// +// initializePlugin(); +// plugin.invalidateCurrentConnection(); +// +// doThrow(new SQLException()).when(mockConnection).close(); +// assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); +// +// verify(mockConnection, times(2)).isClosed(); +// verify(mockConnection, times(2)).close(); +// } +// +// @Test +// void test_execute_withFailoverDisabled() throws SQLException { +// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); +// initializePlugin(); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// mockSqlFunction, +// EMPTY_ARGS); +// +// verify(mockSqlFunction).call(); +// verify(mockHostListProvider, never()).getRdsUrlType(); +// } +// +// @Test +// void test_execute_withDirectExecute() throws SQLException { +// initializePlugin(); +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// "close", +// mockSqlFunction, +// EMPTY_ARGS); +// verify(mockSqlFunction).call(); +// verify(mockHostListProvider, never()).getRdsUrlType(); +// } +// +// private void initializePlugin() { +// plugin = new FailoverConnectionPlugin(mockPluginService, properties); +// plugin.setWriterFailoverHandler(mockWriterFailoverHandler); +// plugin.setReaderFailoverHandler(mockReaderFailoverHandler); +// } +// } From 29a58cbffea247fc0dd782378b8652900dab1e6f Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 15 Aug 2025 14:13:12 -0700 Subject: [PATCH 02/54] Replace PluginService#forceConnect with ConnectionService#open in ClusterAwareWriterFailoverHandler --- .../ClusterAwareWriterFailoverHandler.java | 97 ++++++++++++------- 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 07d04d4f9..3972ada6a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -141,7 +141,7 @@ public WriterFailoverResult failover(final List currentTopology) } } - private HostSpec getWriter(final List topology) { + private static HostSpec getWriter(final List topology) { if (topology == null || topology.isEmpty()) { return null; } @@ -158,13 +158,22 @@ private void submitTasks( final List currentTopology, final ExecutorService executorService, final CompletionService completionService, final boolean singleTask) { - final HostSpec writerHost = this.getWriter(currentTopology); + final HostSpec writerHost = getWriter(currentTopology); if (!singleTask) { - completionService.submit(new ReconnectToWriterHandler(writerHost)); + completionService.submit( + new ReconnectToWriterHandler( + this.connectionService, writerHost, this.initialConnectionProps, this.reconnectWriterIntervalMs)); } - completionService.submit(new WaitForNewWriterHandler( - currentTopology, - writerHost)); + + completionService.submit( + new WaitForNewWriterHandler( + this.connectionService, + this.readerFailoverHandler, + writerHost, + this.initialConnectionProps, + this.readTopologyIntervalMs, + currentTopology)); + executorService.shutdown(); } @@ -240,19 +249,30 @@ private SQLException createInterruptedException(final InterruptedException e) { /** * Internal class responsible for re-connecting to the current writer (aka TaskA). */ - private class ReconnectToWriterHandler implements Callable { + private static class ReconnectToWriterHandler implements Callable { + private final ConnectionService connectionService; private final HostSpec originalWriterHost; - - public ReconnectToWriterHandler(final HostSpec originalWriterHost) { + private final Properties props; + private final int reconnectWriterIntervalMs; + private PluginService pluginService = null; + + public ReconnectToWriterHandler( + final ConnectionService connectionService, + final HostSpec originalWriterHost, + final Properties props, + final int reconnectWriterIntervalMs) { + this.connectionService = connectionService; this.originalWriterHost = originalWriterHost; + this.props = props; + this.reconnectWriterIntervalMs = reconnectWriterIntervalMs; } public WriterFailoverResult call() { LOGGER.fine( () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskAAttemptReconnectToWriterInstance", - new Object[] {this.originalWriterHost.getUrl(), PropertyUtils.maskProperties(initialConnectionProps)})); + new Object[] {this.originalWriterHost.getUrl(), PropertyUtils.maskProperties(this.props)})); Connection conn = null; List latestTopology = null; @@ -265,15 +285,14 @@ public WriterFailoverResult call() { conn.close(); } - // TODO: assess whether multi-threaded access to the plugin service is safe. The same plugin service is used - // by both the ConnectionWrapper and this ReconnectToWriterHandler in separate threads. - conn = connectionService.open(this.originalWriterHost, initialConnectionProps); - pluginService.forceRefreshHostList(conn); - latestTopology = pluginService.getAllHosts(); - + conn = connectionService.open(this.originalWriterHost, this.props); + this.pluginService = conn.unwrap(PluginService.class); + this.pluginService.forceRefreshHostList(conn); + latestTopology = this.pluginService.getAllHosts(); } catch (final SQLException exception) { // Propagate exceptions that are not caused by network errors. - if (!pluginService.isNetworkException(exception, pluginService.getTargetDriverDialect())) { + if (this.pluginService != null + && !pluginService.isNetworkException(exception, pluginService.getTargetDriverDialect())) { LOGGER.finer( () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskAEncounteredException", @@ -323,26 +342,39 @@ private boolean isCurrentHostWriter(final List latestTopology) { * Internal class responsible for getting the latest cluster topology and connecting to a newly * elected writer (aka TaskB). */ - private class WaitForNewWriterHandler implements Callable { + private static class WaitForNewWriterHandler implements Callable { private Connection currentConnection = null; + private PluginService pluginService = null; + private final ConnectionService connectionService; + private final ReaderFailoverHandler readerFailoverHandler; private final HostSpec originalWriterHost; + private final Properties props; + private final int readTopologyIntervalMs; private List currentTopology; private HostSpec currentReaderHost; private Connection currentReaderConnection; public WaitForNewWriterHandler( - final List currentTopology, - final HostSpec currentHost) { + final ConnectionService connectionService, + final ReaderFailoverHandler readerFailoverHandler, + final HostSpec originalWriterHost, + final Properties props, + final int readTopologyIntervalMs, + final List currentTopology) { + this.connectionService = connectionService; + this.readerFailoverHandler = readerFailoverHandler; + this.originalWriterHost = originalWriterHost; + this.props = props; + this.readTopologyIntervalMs = readTopologyIntervalMs; this.currentTopology = currentTopology; - this.originalWriterHost = currentHost; } public WriterFailoverResult call() { LOGGER.finer( () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskBAttemptConnectionToNewWriterInstance", - new Object[] {PropertyUtils.maskProperties(initialConnectionProps)})); + new Object[] {PropertyUtils.maskProperties(this.props)})); try { boolean success = false; @@ -381,6 +413,7 @@ private void connectToReader() throws InterruptedException { if (isValidReaderConnection(connResult)) { this.currentReaderConnection = connResult.getConnection(); this.currentReaderHost = connResult.getHost(); + this.pluginService = this.currentReaderConnection.unwrap(PluginService.class); LOGGER.fine( () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskBConnectedToReader", @@ -396,10 +429,7 @@ private void connectToReader() throws InterruptedException { } private boolean isValidReaderConnection(final ReaderFailoverResult result) { - if (!result.isConnected() || result.getConnection() == null || result.getHost() == null) { - return false; - } - return true; + return result.isConnected() && result.getConnection() != null && result.getHost() != null; } /** @@ -408,14 +438,14 @@ private boolean isValidReaderConnection(final ReaderFailoverResult result) { * @return Returns true if successful. */ private boolean refreshTopologyAndConnectToNewWriter() throws InterruptedException { - boolean allowOldWriter = pluginService.getDialect() + boolean allowOldWriter = this.pluginService.getDialect() .getFailoverRestrictions() .contains(FailoverRestriction.ENABLE_WRITER_IN_TASK_B); while (true) { try { - pluginService.forceRefreshHostList(this.currentReaderConnection); - final List topology = pluginService.getAllHosts(); + this.pluginService.forceRefreshHostList(this.currentReaderConnection); + final List topology = this.pluginService.getAllHosts(); if (!topology.isEmpty()) { @@ -474,13 +504,12 @@ private boolean connectToWriter(final HostSpec writerCandidate) { new Object[] {writerCandidate.getUrl()})); try { // connect to the new writer - // TODO: assess whether multi-threaded access to the plugin service is safe. The same plugin service is used - // by both the ConnectionWrapper and this WaitForNewWriterHandler in separate threads. - this.currentConnection = pluginService.forceConnect(writerCandidate, initialConnectionProps); - pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.AVAILABLE); + this.currentConnection = this.connectionService.open(writerCandidate, this.props); + this.pluginService = this.currentConnection.unwrap(PluginService.class); + this.pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.AVAILABLE); return true; } catch (final SQLException exception) { - pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.NOT_AVAILABLE); + this.pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.NOT_AVAILABLE); return false; } } From f59930d5235c5af9e13bb26a6a8070bbe439a7b9 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 15 Aug 2025 16:05:58 -0700 Subject: [PATCH 03/54] Add notes for how to fix the current issues --- .../jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java | 2 ++ .../amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 3972ada6a..1e07b80b8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -444,6 +444,7 @@ private boolean refreshTopologyAndConnectToNewWriter() throws InterruptedExcepti while (true) { try { + // TODO: replace with host list provider this.pluginService.forceRefreshHostList(this.currentReaderConnection); final List topology = this.pluginService.getAllHosts(); @@ -506,6 +507,7 @@ private boolean connectToWriter(final HostSpec writerCandidate) { // connect to the new writer this.currentConnection = this.connectionService.open(writerCandidate, this.props); this.pluginService = this.currentConnection.unwrap(PluginService.class); + // TODO: replace with a map that is shared between the two handlers this.pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.AVAILABLE); return true; } catch (final SQLException exception) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 1807a2f8c..0e22b20de 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -639,6 +639,8 @@ protected void dealWithIllegalStateException( protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); + // TODO: instantiate ConnectionService here + // After failover, retrieve the unavailable hosts from the handlers after replacing pluginService#setAvailability if (this.failoverMode == FailoverMode.STRICT_WRITER) { failoverWriter(); } else { From 8fc6d1a4fe83173c4c43555354cb1ef6358e576b Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 18 Aug 2025 16:56:02 -0700 Subject: [PATCH 04/54] Integ test passing --- .../ClusterAwareWriterFailoverHandler.java | 61 +++++++++++++------ .../failover/FailoverConnectionPlugin.java | 47 +++++++------- .../failover/WriterFailoverHandler.java | 4 +- 3 files changed, 68 insertions(+), 44 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 1e07b80b8..72d2a5ad3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -32,7 +32,9 @@ import java.util.logging.Logger; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.HostListProviderSupplier; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; @@ -57,26 +59,24 @@ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler protected int readTopologyIntervalMs = 5000; // 5 sec protected int reconnectWriterIntervalMs = 5000; // 5 sec protected Properties initialConnectionProps; + protected FullServicesContainer servicesContainer; protected PluginService pluginService; - protected ConnectionService connectionService; protected ReaderFailoverHandler readerFailoverHandler; private static final WriterFailoverResult DEFAULT_RESULT = new WriterFailoverResult(false, false, null, null, "None"); public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, - final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps) { + this.servicesContainer = servicesContainer; this.pluginService = servicesContainer.getPluginService(); - this.connectionService = connectionService; this.readerFailoverHandler = readerFailoverHandler; this.initialConnectionProps = initialConnectionProps; } public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, - final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps, final int failoverTimeoutMs, @@ -84,7 +84,6 @@ public ClusterAwareWriterFailoverHandler( final int reconnectWriterIntervalMs) { this( servicesContainer, - connectionService, readerFailoverHandler, initialConnectionProps); this.maxFailoverTimeoutMs = failoverTimeoutMs; @@ -99,7 +98,7 @@ public ClusterAwareWriterFailoverHandler( * @return {@link WriterFailoverResult} The results of this process. */ @Override - public WriterFailoverResult failover(final List currentTopology) + public WriterFailoverResult failover(final ConnectionService connectionService, final List currentTopology) throws SQLException { if (Utils.isNullOrEmpty(currentTopology)) { LOGGER.severe(() -> Messages.get("ClusterAwareWriterFailoverHandler.failoverCalledWithInvalidTopology")); @@ -112,7 +111,7 @@ public WriterFailoverResult failover(final List currentTopology) final ExecutorService executorService = ExecutorFactory.newFixedThreadPool(2, "failover"); final CompletionService completionService = new ExecutorCompletionService<>(executorService); - submitTasks(currentTopology, executorService, completionService, singleTask); + submitTasks(connectionService, currentTopology, executorService, completionService, singleTask); try { final long startTimeNano = System.nanoTime(); @@ -155,19 +154,26 @@ private static HostSpec getWriter(final List topology) { } private void submitTasks( - final List currentTopology, final ExecutorService executorService, + final ConnectionService connectionService, + final List currentTopology, + final ExecutorService executorService, final CompletionService completionService, final boolean singleTask) { final HostSpec writerHost = getWriter(currentTopology); if (!singleTask) { completionService.submit( new ReconnectToWriterHandler( - this.connectionService, writerHost, this.initialConnectionProps, this.reconnectWriterIntervalMs)); + connectionService, + this.getNewPluginService(), + writerHost, + this.initialConnectionProps, + this.reconnectWriterIntervalMs)); } completionService.submit( new WaitForNewWriterHandler( - this.connectionService, + connectionService, + this.getNewPluginService(), this.readerFailoverHandler, writerHost, this.initialConnectionProps, @@ -177,6 +183,24 @@ private void submitTasks( executorService.shutdown(); } + private PluginService getNewPluginService() { + PartialPluginService partialPluginService = new PartialPluginService( + this.servicesContainer, + this.initialConnectionProps, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect() + ); + + // TODO: can we clean this up, eg move to PartialPluginService constructor? + final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); + partialPluginService.setHostListProvider( + supplier.getProvider(this.initialConnectionProps, this.pluginService.getOriginalUrl(), this.servicesContainer)); + + return partialPluginService; + } + private WriterFailoverResult getNextResult( final ExecutorService executorService, final CompletionService completionService, @@ -255,14 +279,16 @@ private static class ReconnectToWriterHandler implements Callable Messages.get( "ClusterAwareWriterFailoverHandler.taskAEncounteredException", @@ -344,25 +368,27 @@ private boolean isCurrentHostWriter(final List latestTopology) { */ private static class WaitForNewWriterHandler implements Callable { - private Connection currentConnection = null; - private PluginService pluginService = null; private final ConnectionService connectionService; + private final PluginService pluginService; private final ReaderFailoverHandler readerFailoverHandler; private final HostSpec originalWriterHost; private final Properties props; private final int readTopologyIntervalMs; + private Connection currentConnection = null; private List currentTopology; private HostSpec currentReaderHost; private Connection currentReaderConnection; public WaitForNewWriterHandler( final ConnectionService connectionService, + final PluginService pluginService, final ReaderFailoverHandler readerFailoverHandler, final HostSpec originalWriterHost, final Properties props, final int readTopologyIntervalMs, final List currentTopology) { this.connectionService = connectionService; + this.pluginService = pluginService; this.readerFailoverHandler = readerFailoverHandler; this.originalWriterHost = originalWriterHost; this.props = props; @@ -413,7 +439,6 @@ private void connectToReader() throws InterruptedException { if (isValidReaderConnection(connResult)) { this.currentReaderConnection = connResult.getConnection(); this.currentReaderHost = connResult.getHost(); - this.pluginService = this.currentReaderConnection.unwrap(PluginService.class); LOGGER.fine( () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskBConnectedToReader", @@ -444,7 +469,6 @@ private boolean refreshTopologyAndConnectToNewWriter() throws InterruptedExcepti while (true) { try { - // TODO: replace with host list provider this.pluginService.forceRefreshHostList(this.currentReaderConnection); final List topology = this.pluginService.getAllHosts(); @@ -506,7 +530,6 @@ private boolean connectToWriter(final HostSpec writerCandidate) { try { // connect to the new writer this.currentConnection = this.connectionService.open(writerCandidate, this.props); - this.pluginService = this.currentConnection.unwrap(PluginService.class); // TODO: replace with a map that is shared between the two handlers this.pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.AVAILABLE); return true; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 0e22b20de..817e5750d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -99,7 +99,7 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { private final Set subscribedMethods; private final PluginService pluginService; private final FullServicesContainer servicesContainer; - private final ConnectionService connectionService; + private ConnectionService connectionService; protected final Properties properties; protected boolean enableFailoverSetting; protected boolean enableConnectFailover; @@ -231,25 +231,6 @@ public FailoverConnectionPlugin(final FullServicesContainer servicesContainer, f } this.subscribedMethods = Collections.unmodifiableSet(methods); - try { - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); - this.connectionService = new ConnectionServiceImpl( - servicesContainer.getStorageService(), - servicesContainer.getMonitorService(), - servicesContainer.getTelemetryFactory(), - defaultConnectionProvider, - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - properties - ); - } catch (SQLException e) { - throw new RuntimeException(e); - } - TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); this.failoverWriterTriggeredCounter = telemetryFactory.createCounter("writerFailover.triggered.count"); this.failoverWriterSuccessCounter = telemetryFactory.createCounter("writerFailover.completed.success.count"); @@ -345,7 +326,6 @@ public void initHostProvider( () -> new ClusterAwareWriterFailoverHandler( this.servicesContainer, - this.connectionService, this.readerFailoverHandler, this.properties, this.failoverTimeoutMsSetting, @@ -639,8 +619,26 @@ protected void dealWithIllegalStateException( protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); - // TODO: instantiate ConnectionService here - // After failover, retrieve the unavailable hosts from the handlers after replacing pluginService#setAvailability + + // TODO: After failover, retrieve the unavailable hosts from the handlers after replacing + // pluginService#setAvailability + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + if (this.connectionService == null) { + this.connectionService = new ConnectionServiceImpl( + servicesContainer.getStorageService(), + servicesContainer.getMonitorService(), + servicesContainer.getTelemetryFactory(), + defaultConnectionProvider, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect(), + properties + ); + } + if (this.failoverMode == FailoverMode.STRICT_WRITER) { failoverWriter(); } else { @@ -751,7 +749,8 @@ protected void failoverWriter() throws SQLException { try { LOGGER.info(() -> Messages.get("Failover.startWriterFailover")); - final WriterFailoverResult failoverResult = this.writerFailoverHandler.failover(this.pluginService.getAllHosts()); + final WriterFailoverResult failoverResult = + this.writerFailoverHandler.failover(this.connectionService, this.pluginService.getAllHosts()); if (failoverResult != null) { final SQLException exception = failoverResult.getException(); if (exception != null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java index caae8de8c..68e69259b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java @@ -19,6 +19,7 @@ import java.sql.SQLException; import java.util.List; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.util.connection.ConnectionService; /** * Interface for Writer Failover Process handler. This handler implements all necessary logic to try @@ -33,5 +34,6 @@ public interface WriterFailoverHandler { * @return {@link WriterFailoverResult} The results of this process. * @throws SQLException indicating whether the failover attempt was successful. */ - WriterFailoverResult failover(List currentTopology) throws SQLException; + WriterFailoverResult failover( + ConnectionService connectionService, List currentTopology) throws SQLException; } From b4e5b3aa10d473e8a342ad63c5f8f836386210ad Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 19 Aug 2025 16:18:33 -0700 Subject: [PATCH 05/54] Pass ConnectionService to ReaderFailoverHandler --- .../plugin/failover/ClusterAwareWriterFailoverHandler.java | 3 ++- .../jdbc/plugin/failover/FailoverConnectionPlugin.java | 3 ++- .../amazon/jdbc/plugin/failover/ReaderFailoverHandler.java | 7 +++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 72d2a5ad3..81098ba65 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -435,7 +435,8 @@ public WriterFailoverResult call() { private void connectToReader() throws InterruptedException { while (true) { try { - final ReaderFailoverResult connResult = readerFailoverHandler.getReaderConnection(this.currentTopology); + final ReaderFailoverResult connResult = + readerFailoverHandler.getReaderConnection(this.connectionService, this.currentTopology); if (isValidReaderConnection(connResult)) { this.currentReaderConnection = connResult.getConnection(); this.currentReaderHost = connResult.getHost(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 817e5750d..c684dbb50 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -665,7 +665,8 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException failedHost = failedHostSpec; } - final ReaderFailoverResult result = readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); + final ReaderFailoverResult result = readerFailoverHandler.failover( + this.connectionService, this.pluginService.getHosts(), failedHost); if (result != null) { final SQLException exception = result.getException(); if (exception != null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java index e006558b6..008c1e17e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java @@ -19,6 +19,7 @@ import java.sql.SQLException; import java.util.List; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.util.connection.ConnectionService; /** * Interface for Reader Failover Process handler. This handler implements all necessary logic to try @@ -36,7 +37,8 @@ public interface ReaderFailoverHandler { * @return {@link ReaderFailoverResult} The results of this process. * @throws SQLException indicating whether the failover attempt was successful. */ - ReaderFailoverResult failover(List hosts, HostSpec currentHost) throws SQLException; + ReaderFailoverResult failover( + ConnectionService connectionService, List hosts, HostSpec currentHost) throws SQLException; /** * Called to get any available reader connection. If no reader is available then result of process @@ -46,5 +48,6 @@ public interface ReaderFailoverHandler { * @return {@link ReaderFailoverResult} The results of this process. * @throws SQLException if any error occurred while attempting a reader connection. */ - ReaderFailoverResult getReaderConnection(List hostList) throws SQLException; + ReaderFailoverResult getReaderConnection( + ConnectionService connectionService, List hostList) throws SQLException; } From a2342befd179e363d893c191a28ccd9d0ecba19d Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 19 Aug 2025 16:46:54 -0700 Subject: [PATCH 06/54] Adapted ReaderFailoverHandler, IT passing --- .../ClusterAwareReaderFailoverHandler.java | 111 ++- .../failover/FailoverConnectionPlugin.java | 2 +- ...ClusterAwareReaderFailoverHandlerTest.java | 814 +++++++++--------- 3 files changed, 485 insertions(+), 442 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index f97a2ed4d..18df01f3c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -33,12 +33,16 @@ import java.util.logging.Logger; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.HostListProviderSupplier; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionService; /** * An implementation of ReaderFailoverHandler. @@ -58,24 +62,25 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler protected static final int DEFAULT_READER_CONNECT_TIMEOUT = 30000; // 30 sec public static final ReaderFailoverResult FAILED_READER_FAILOVER_RESULT = new ReaderFailoverResult(null, null, false); - protected Properties initialConnectionProps; + protected Properties props; protected int maxFailoverTimeoutMs; protected int timeoutMs; protected boolean isStrictReaderRequired; + protected final FullServicesContainer servicesContainer; protected final PluginService pluginService; /** * ClusterAwareReaderFailoverHandler constructor. * - * @param pluginService A provider for creating new connections. - * @param initialConnectionProps The initial connection properties to copy over to the new reader. + * @param servicesContainer A provider for creating new connections. + * @param props The initial connection properties to copy over to the new reader. */ public ClusterAwareReaderFailoverHandler( - final PluginService pluginService, - final Properties initialConnectionProps) { + final FullServicesContainer servicesContainer, + final Properties props) { this( - pluginService, - initialConnectionProps, + servicesContainer, + props, DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, false); @@ -84,21 +89,22 @@ public ClusterAwareReaderFailoverHandler( /** * ClusterAwareReaderFailoverHandler constructor. * - * @param pluginService A provider for creating new connections. - * @param initialConnectionProps The initial connection properties to copy over to the new reader. + * @param servicesContainer A provider for creating new connections. + * @param props The initial connection properties to copy over to the new reader. * @param maxFailoverTimeoutMs Maximum allowed time for the entire reader failover process. * @param timeoutMs Maximum allowed time in milliseconds for each reader connection attempt during * the reader failover process. * @param isStrictReaderRequired When true, it disables adding a writer to a list of nodes to connect */ public ClusterAwareReaderFailoverHandler( - final PluginService pluginService, - final Properties initialConnectionProps, + final FullServicesContainer servicesContainer, + final Properties props, final int maxFailoverTimeoutMs, final int timeoutMs, final boolean isStrictReaderRequired) { - this.pluginService = pluginService; - this.initialConnectionProps = initialConnectionProps; + this.servicesContainer = servicesContainer; + this.pluginService = servicesContainer.getPluginService(); + this.props = props; this.maxFailoverTimeoutMs = maxFailoverTimeoutMs; this.timeoutMs = timeoutMs; this.isStrictReaderRequired = isStrictReaderRequired; @@ -124,7 +130,8 @@ protected void setTimeoutMs(final int timeoutMs) { * @return {@link ReaderFailoverResult} The results of this process. */ @Override - public ReaderFailoverResult failover(final List hosts, final HostSpec currentHost) + public ReaderFailoverResult failover( + final ConnectionService connectionService, final List hosts, final HostSpec currentHost) throws SQLException { if (Utils.isNullOrEmpty(hosts)) { LOGGER.fine(() -> Messages.get("ClusterAwareReaderFailoverHandler.invalidTopology", new Object[] {"failover"})); @@ -133,11 +140,13 @@ public ReaderFailoverResult failover(final List hosts, final HostSpec final ExecutorService executor = ExecutorFactory.newSingleThreadExecutor("failover"); - final Future future = submitInternalFailoverTask(hosts, currentHost, executor); + final Future future = + submitInternalFailoverTask(connectionService, hosts, currentHost, executor); return getInternalFailoverResult(executor, future); } private Future submitInternalFailoverTask( + final ConnectionService connectionService, final List hosts, final HostSpec currentHost, final ExecutorService executor) { @@ -145,7 +154,7 @@ private Future submitInternalFailoverTask( ReaderFailoverResult result; try { while (true) { - result = failoverInternal(hosts, currentHost); + result = failoverInternal(connectionService, hosts, currentHost); if (result != null && result.isConnected()) { return result; } @@ -190,6 +199,7 @@ private ReaderFailoverResult getInternalFailoverResult( } protected ReaderFailoverResult failoverInternal( + final ConnectionService connectionService, final List hosts, final HostSpec currentHost) throws SQLException { @@ -197,7 +207,7 @@ protected ReaderFailoverResult failoverInternal( this.pluginService.setAvailability(currentHost.asAliases(), HostAvailability.NOT_AVAILABLE); } final List hostsByPriority = getHostsByPriority(hosts); - return getConnectionFromHostGroup(hostsByPriority); + return getConnectionFromHostGroup(connectionService, hostsByPriority); } public List getHostsByPriority(final List hosts) { @@ -239,7 +249,8 @@ public List getHostsByPriority(final List hosts) { * @return {@link ReaderFailoverResult} The results of this process. */ @Override - public ReaderFailoverResult getReaderConnection(final List hostList) + public ReaderFailoverResult getReaderConnection( + final ConnectionService connectionService, final List hostList) throws SQLException { if (Utils.isNullOrEmpty(hostList)) { LOGGER.fine( @@ -250,7 +261,7 @@ public ReaderFailoverResult getReaderConnection(final List hostList) } final List hostsByPriority = getReaderHostsByPriority(hostList); - return getConnectionFromHostGroup(hostsByPriority); + return getConnectionFromHostGroup(connectionService, hostsByPriority); } public List getReaderHostsByPriority(final List hosts) { @@ -291,7 +302,8 @@ public List getReaderHostsByPriority(final List hosts) { return hostsByPriority; } - private ReaderFailoverResult getConnectionFromHostGroup(final List hosts) + private ReaderFailoverResult getConnectionFromHostGroup( + final ConnectionService connectionService, final List hosts) throws SQLException { final ExecutorService executor = ExecutorFactory.newFixedThreadPool(2, "failover"); @@ -300,7 +312,8 @@ private ReaderFailoverResult getConnectionFromHostGroup(final List hos try { for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 - final ReaderFailoverResult result = getResultFromNextTaskBatch(hosts, executor, completionService, i); + final ReaderFailoverResult result = + getResultFromNextTaskBatch(connectionService, hosts, executor, completionService, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -323,15 +336,19 @@ private ReaderFailoverResult getConnectionFromHostGroup(final List hos } private ReaderFailoverResult getResultFromNextTaskBatch( + final ConnectionService connectionService, final List hosts, final ExecutorService executor, final CompletionService completionService, final int i) throws SQLException { ReaderFailoverResult result; final int numTasks = i + 1 < hosts.size() ? 2 : 1; - completionService.submit(new ConnectionAttemptTask(hosts.get(i), this.isStrictReaderRequired)); + completionService.submit( + // TODO: are there performance concerns with creating a new plugin service this often? + new ConnectionAttemptTask( + connectionService, this.getNewPluginService(), hosts.get(i), this.props, this.isStrictReaderRequired)); if (numTasks == 2) { - completionService.submit(new ConnectionAttemptTask(hosts.get(i + 1), this.isStrictReaderRequired)); + completionService.submit(new ConnectionAttemptTask(connectionService, this.getNewPluginService(), hosts.get(i + 1), this.props, this.isStrictReaderRequired)); } for (int taskNum = 0; taskNum < numTasks; taskNum++) { result = getNextResult(completionService); @@ -372,13 +389,41 @@ private ReaderFailoverResult getNextResult(final CompletionService { + private PluginService getNewPluginService() { + PartialPluginService partialPluginService = new PartialPluginService( + this.servicesContainer, + this.props, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect() + ); + + // TODO: can we clean this up, eg move to PartialPluginService constructor? + final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); + partialPluginService.setHostListProvider( + supplier.getProvider(this.props, this.pluginService.getOriginalUrl(), this.servicesContainer)); + + return partialPluginService; + } + private static class ConnectionAttemptTask implements Callable { + private final ConnectionService connectionService; + private final PluginService pluginService; private final HostSpec newHost; + private final Properties props; private final boolean isStrictReaderRequired; - private ConnectionAttemptTask(final HostSpec newHost, final boolean isStrictReaderRequired) { + private ConnectionAttemptTask( + final ConnectionService connectionService, + final PluginService pluginService, + final HostSpec newHost, + final Properties props, + final boolean isStrictReaderRequired) { + this.connectionService = connectionService; + this.pluginService = pluginService; this.newHost = newHost; + this.props = props; this.isStrictReaderRequired = isStrictReaderRequired; } @@ -390,21 +435,19 @@ public ReaderFailoverResult call() { LOGGER.fine( () -> Messages.get( "ClusterAwareReaderFailoverHandler.attemptingReaderConnection", - new Object[] {this.newHost.getUrl(), PropertyUtils.maskProperties(initialConnectionProps)})); + new Object[] {this.newHost.getUrl(), PropertyUtils.maskProperties(props)})); try { final Properties copy = new Properties(); - copy.putAll(initialConnectionProps); + copy.putAll(props); - // TODO: assess whether multi-threaded access to the plugin service is safe. The same plugin service is used by - // both the ConnectionWrapper and this ConnectionAttemptTask in separate threads. - final Connection conn = pluginService.forceConnect(this.newHost, copy); - pluginService.setAvailability(this.newHost.asAliases(), HostAvailability.AVAILABLE); + final Connection conn = this.connectionService.open(this.newHost, copy); + this.pluginService.setAvailability(this.newHost.asAliases(), HostAvailability.AVAILABLE); if (this.isStrictReaderRequired) { // need to ensure that new connection is a connection to a reader node try { - HostRole role = pluginService.getHostRole(conn); + HostRole role = this.pluginService.getHostRole(conn); if (!HostRole.READER.equals(role)) { LOGGER.fine( Messages.get( @@ -439,13 +482,13 @@ public ReaderFailoverResult call() { LOGGER.fine("New reader failover connection object: " + conn); return new ReaderFailoverResult(conn, this.newHost, true); } catch (final SQLException e) { - pluginService.setAvailability(newHost.asAliases(), HostAvailability.NOT_AVAILABLE); + this.pluginService.setAvailability(newHost.asAliases(), HostAvailability.NOT_AVAILABLE); LOGGER.fine( () -> Messages.get( "ClusterAwareReaderFailoverHandler.failedReaderConnection", new Object[] {this.newHost.getUrl()})); // Propagate exceptions that are not caused by network errors. - if (!pluginService.isNetworkException(e, pluginService.getTargetDriverDialect())) { + if (!this.pluginService.isNetworkException(e, this.pluginService.getTargetDriverDialect())) { return new ReaderFailoverResult( null, null, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index c684dbb50..f378d3c38 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -318,7 +318,7 @@ public void initHostProvider( initHostProviderFunc, () -> new ClusterAwareReaderFailoverHandler( - this.pluginService, + this.servicesContainer, this.properties, this.failoverTimeoutMsSetting, this.failoverReaderConnectTimeoutMsSetting, diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java index a22283c39..7eed3e5e8 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java @@ -1,407 +1,407 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; - -class ClusterAwareReaderFailoverHandlerTest { - - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final List defaultHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader2").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader3").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader4").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader5").port(1234).role(HostRole.READER).build() - ); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testFailover() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: successful connection for host at index 4 - final List hosts = defaultHosts; - final int currentHostIndex = 2; - final int successHostIndex = 4; - for (int i = 0; i < hosts.size(); i++) { - if (i != successHostIndex) { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockPluginService.forceConnect(hosts.get(i), properties)) - .thenThrow(exception); - when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); - } else { - when(mockPluginService.forceConnect(hosts.get(i), properties)).thenReturn(mockConnection); - } - } - when(mockPluginService.getTargetDriverDialect()).thenReturn(null); - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(successHostIndex), result.getHost()); - - final HostSpec successHost = hosts.get(successHostIndex); - verify(mockPluginService, atLeast(4)).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); - verify(mockPluginService, never()) - .setAvailability(eq(successHost.asAliases()), eq(HostAvailability.NOT_AVAILABLE)); - verify(mockPluginService, times(1)) - .setAvailability(eq(successHost.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - @Test - public void testFailover_timeout() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: failure to get reader since process is limited to 5s and each attempt - // to connect takes 20s - final List hosts = defaultHosts; - final int currentHostIndex = 2; - for (HostSpec host : hosts) { - when(mockPluginService.forceConnect(host, properties)) - .thenAnswer((Answer) invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - } - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties, - 5000, - 30000, - false); - - final long startTimeNano = System.nanoTime(); - final ReaderFailoverResult result = - target.failover(hosts, hosts.get(currentHostIndex)); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - @Test - public void testFailover_nullOrEmptyHostList() throws SQLException { - final ClusterAwareReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties); - final HostSpec currentHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer") - .port(1234).build(); - - ReaderFailoverResult result = target.failover(null, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - final List hosts = new ArrayList<>(); - result = target.failover(hosts, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionSuccess() throws SQLException { - // even number of connection attempts - // first connection attempt to return succeeds, second attempt cancelled - // expected test result: successful connection for host at index 2 - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - final HostSpec slowHost = hosts.get(1); - final HostSpec fastHost = hosts.get(2); - when(mockPluginService.forceConnect(slowHost, properties)) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - when(mockPluginService.forceConnect(eq(fastHost), eq(properties))).thenReturn(mockConnection); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(2), result.getHost()); - - verify(mockPluginService, never()).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); - verify(mockPluginService, times(1)) - .setAvailability(eq(fastHost.asAliases()), eq(HostAvailability.AVAILABLE)); - } - - @Test - public void testGetReader_connectionFailure() throws SQLException { - // odd number of connection attempts - // first connection attempt to return fails - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) - when(mockPluginService.forceConnect(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final int currentHostIndex = 2; - - final ReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionAttemptsTimeout() throws SQLException { - // connection attempts time out before they can succeed - // first connection attempt to return times out - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - when(mockPluginService.forceConnect(any(), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - try { - Thread.sleep(5000); - } catch (InterruptedException exception) { - // ignore - } - return mockConnection; - }); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties, - 60000, - 1000, - false); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetHostTuplesByPriority() { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ClusterAwareReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties); - final List hostsByPriority = target.getHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting a writer - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.WRITER) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testGetReaderTuplesByPriority() { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties); - final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testHostFailoverStrictReaderEnabled() { - - final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(); - final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(); - final List hosts = Arrays.asList(writer, reader); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - final ClusterAwareReaderFailoverHandler target = - new ClusterAwareReaderFailoverHandler( - mockPluginService, - properties, - DEFAULT_FAILOVER_TIMEOUT, - DEFAULT_READER_CONNECT_TIMEOUT, - true); - - // The writer is included because the original writer has likely become a reader. - List expectedHostsByPriority = Arrays.asList(reader, writer); - - List hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. - reader.setAvailability(HostAvailability.NOT_AVAILABLE); - expectedHostsByPriority = Arrays.asList(writer, reader); - - hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Writer node will only be picked if it is the only node in topology; - List expectedWriterHost = Collections.singletonList(writer); - - hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); - assertEquals(expectedWriterHost, hostsByPriority); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.atLeast; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// +// class ClusterAwareReaderFailoverHandlerTest { +// +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final List defaultHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader2").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader3").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader4").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader5").port(1234).role(HostRole.READER).build() +// ); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testFailover() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: successful connection for host at index 4 +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// final int successHostIndex = 4; +// for (int i = 0; i < hosts.size(); i++) { +// if (i != successHostIndex) { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockPluginService.forceConnect(hosts.get(i), properties)) +// .thenThrow(exception); +// when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); +// } else { +// when(mockPluginService.forceConnect(hosts.get(i), properties)).thenReturn(mockConnection); +// } +// } +// when(mockPluginService.getTargetDriverDialect()).thenReturn(null); +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(successHostIndex), result.getHost()); +// +// final HostSpec successHost = hosts.get(successHostIndex); +// verify(mockPluginService, atLeast(4)).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); +// verify(mockPluginService, never()) +// .setAvailability(eq(successHost.asAliases()), eq(HostAvailability.NOT_AVAILABLE)); +// verify(mockPluginService, times(1)) +// .setAvailability(eq(successHost.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// @Test +// public void testFailover_timeout() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: failure to get reader since process is limited to 5s and each attempt +// // to connect takes 20s +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// for (HostSpec host : hosts) { +// when(mockPluginService.forceConnect(host, properties)) +// .thenAnswer((Answer) invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// } +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties, +// 5000, +// 30000, +// false); +// +// final long startTimeNano = System.nanoTime(); +// final ReaderFailoverResult result = +// target.failover(hosts, hosts.get(currentHostIndex)); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// @Test +// public void testFailover_nullOrEmptyHostList() throws SQLException { +// final ClusterAwareReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties); +// final HostSpec currentHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer") +// .port(1234).build(); +// +// ReaderFailoverResult result = target.failover(null, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// final List hosts = new ArrayList<>(); +// result = target.failover(hosts, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionSuccess() throws SQLException { +// // even number of connection attempts +// // first connection attempt to return succeeds, second attempt cancelled +// // expected test result: successful connection for host at index 2 +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// final HostSpec slowHost = hosts.get(1); +// final HostSpec fastHost = hosts.get(2); +// when(mockPluginService.forceConnect(slowHost, properties)) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// when(mockPluginService.forceConnect(eq(fastHost), eq(properties))).thenReturn(mockConnection); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(2), result.getHost()); +// +// verify(mockPluginService, never()).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); +// verify(mockPluginService, times(1)) +// .setAvailability(eq(fastHost.asAliases()), eq(HostAvailability.AVAILABLE)); +// } +// +// @Test +// public void testGetReader_connectionFailure() throws SQLException { +// // odd number of connection attempts +// // first connection attempt to return fails +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) +// when(mockPluginService.forceConnect(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final int currentHostIndex = 2; +// +// final ReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionAttemptsTimeout() throws SQLException { +// // connection attempts time out before they can succeed +// // first connection attempt to return times out +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// when(mockPluginService.forceConnect(any(), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// try { +// Thread.sleep(5000); +// } catch (InterruptedException exception) { +// // ignore +// } +// return mockConnection; +// }); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties, +// 60000, +// 1000, +// false); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetHostTuplesByPriority() { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ClusterAwareReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties); +// final List hostsByPriority = target.getHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting a writer +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.WRITER) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testGetReaderTuplesByPriority() { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties); +// final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testHostFailoverStrictReaderEnabled() { +// +// final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(); +// final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(); +// final List hosts = Arrays.asList(writer, reader); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// final ClusterAwareReaderFailoverHandler target = +// new ClusterAwareReaderFailoverHandler( +// mockPluginService, +// properties, +// DEFAULT_FAILOVER_TIMEOUT, +// DEFAULT_READER_CONNECT_TIMEOUT, +// true); +// +// // The writer is included because the original writer has likely become a reader. +// List expectedHostsByPriority = Arrays.asList(reader, writer); +// +// List hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. +// reader.setAvailability(HostAvailability.NOT_AVAILABLE); +// expectedHostsByPriority = Arrays.asList(writer, reader); +// +// hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Writer node will only be picked if it is the only node in topology; +// List expectedWriterHost = Collections.singletonList(writer); +// +// hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); +// assertEquals(expectedWriterHost, hostsByPriority); +// } +// } From 81085513096aad768f577ec8f2bb9cb7a907c6ab Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 20 Aug 2025 09:40:55 -0700 Subject: [PATCH 07/54] Add hostAvailabilityMap to WriterFailoverResult --- .../ClusterAwareWriterFailoverHandler.java | 44 +++++++++++-------- .../failover/FailoverConnectionPlugin.java | 13 ++++++ .../plugin/failover/WriterFailoverResult.java | 12 ++++- 3 files changed, 50 insertions(+), 19 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 81098ba65..711cdb355 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -19,10 +19,12 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; @@ -52,7 +54,6 @@ * same writer host, 2) try to update cluster topology and connect to a newly elected writer. */ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler { - private static final Logger LOGGER = Logger.getLogger(ClusterAwareReaderFailoverHandler.class.getName()); protected int maxFailoverTimeoutMs = 60000; // 60 sec @@ -62,8 +63,6 @@ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler protected FullServicesContainer servicesContainer; protected PluginService pluginService; protected ReaderFailoverHandler readerFailoverHandler; - private static final WriterFailoverResult DEFAULT_RESULT = - new WriterFailoverResult(false, false, null, null, "None"); public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, @@ -102,7 +101,7 @@ public WriterFailoverResult failover(final ConnectionService connectionService, throws SQLException { if (Utils.isNullOrEmpty(currentTopology)) { LOGGER.severe(() -> Messages.get("ClusterAwareWriterFailoverHandler.failoverCalledWithInvalidTopology")); - return DEFAULT_RESULT; + return new WriterFailoverResult(false, false, null, null, null, "None"); } final boolean singleTask = @@ -132,7 +131,7 @@ public WriterFailoverResult failover(final ConnectionService connectionService, } LOGGER.fine(() -> Messages.get("ClusterAwareWriterFailoverHandler.failedToConnectToWriterInstance")); - return DEFAULT_RESULT; + return new WriterFailoverResult(false, false, null, result.getHostAvailabilityMap(), null, "None"); } finally { if (!executorService.isTerminated()) { executorService.shutdownNow(); // terminate all remaining tasks @@ -160,11 +159,13 @@ private void submitTasks( final CompletionService completionService, final boolean singleTask) { final HostSpec writerHost = getWriter(currentTopology); + final Map availabilityMap = new ConcurrentHashMap<>(); if (!singleTask) { completionService.submit( new ReconnectToWriterHandler( connectionService, this.getNewPluginService(), + availabilityMap, writerHost, this.initialConnectionProps, this.reconnectWriterIntervalMs)); @@ -174,6 +175,7 @@ private void submitTasks( new WaitForNewWriterHandler( connectionService, this.getNewPluginService(), + availabilityMap, this.readerFailoverHandler, writerHost, this.initialConnectionProps, @@ -210,7 +212,7 @@ private WriterFailoverResult getNextResult( timeoutMs, TimeUnit.MILLISECONDS); if (firstCompleted == null) { // The task was unsuccessful and we have timed out - return DEFAULT_RESULT; + return new WriterFailoverResult(false, false, null, null, null, "None"); } final WriterFailoverResult result = firstCompleted.get(); if (result.isConnected()) { @@ -229,7 +231,7 @@ private WriterFailoverResult getNextResult( } catch (final ExecutionException e) { // return failure below } - return DEFAULT_RESULT; + return new WriterFailoverResult(false, false, null, null, null, "None"); } private void logTaskSuccess(final WriterFailoverResult result) { @@ -276,19 +278,22 @@ private SQLException createInterruptedException(final InterruptedException e) { private static class ReconnectToWriterHandler implements Callable { private final ConnectionService connectionService; + private final PluginService pluginService; + private final Map availabilityMap; private final HostSpec originalWriterHost; private final Properties props; private final int reconnectWriterIntervalMs; - private final PluginService pluginService; public ReconnectToWriterHandler( final ConnectionService connectionService, final PluginService pluginService, + final Map availabilityMap, final HostSpec originalWriterHost, final Properties props, final int reconnectWriterIntervalMs) { this.connectionService = connectionService; this.pluginService = pluginService; + this.availabilityMap = availabilityMap; this.originalWriterHost = originalWriterHost; this.props = props; this.reconnectWriterIntervalMs = reconnectWriterIntervalMs; @@ -321,7 +326,7 @@ public WriterFailoverResult call() { () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskAEncounteredException", new Object[] {exception})); - return new WriterFailoverResult(false, false, null, null, "TaskA", exception); + return new WriterFailoverResult(false, false, null, this.availabilityMap, null, "TaskA", exception); } } @@ -332,14 +337,14 @@ public WriterFailoverResult call() { success = isCurrentHostWriter(latestTopology); LOGGER.finest("[TaskA] success: " + success); - pluginService.setAvailability(this.originalWriterHost.asAliases(), HostAvailability.AVAILABLE); - return new WriterFailoverResult(success, false, latestTopology, success ? conn : null, "TaskA"); + availabilityMap.put(this.originalWriterHost.getHost(), HostAvailability.AVAILABLE); + return new WriterFailoverResult(success, false, latestTopology, this.availabilityMap, success ? conn : null, "TaskA"); } catch (final InterruptedException exception) { Thread.currentThread().interrupt(); - return new WriterFailoverResult(success, false, latestTopology, success ? conn : null, "TaskA"); + return new WriterFailoverResult(success, false, latestTopology, this.availabilityMap, success ? conn : null, "TaskA"); } catch (final Exception ex) { - LOGGER.severe(() -> ex.getMessage()); - return new WriterFailoverResult(false, false, null, null, "TaskA"); + LOGGER.severe(ex::getMessage); + return new WriterFailoverResult(false, false, null, this.availabilityMap, null, "TaskA"); } finally { try { if (conn != null && !success && !conn.isClosed()) { @@ -370,6 +375,7 @@ private static class WaitForNewWriterHandler implements Callable availabilityMap; private final ReaderFailoverHandler readerFailoverHandler; private final HostSpec originalWriterHost; private final Properties props; @@ -382,6 +388,7 @@ private static class WaitForNewWriterHandler implements Callable availabilityMap, final ReaderFailoverHandler readerFailoverHandler, final HostSpec originalWriterHost, final Properties props, @@ -389,6 +396,7 @@ public WaitForNewWriterHandler( final List currentTopology) { this.connectionService = connectionService; this.pluginService = pluginService; + this.availabilityMap = availabilityMap; this.readerFailoverHandler = readerFailoverHandler; this.originalWriterHost = originalWriterHost; this.props = props; @@ -415,11 +423,12 @@ public WriterFailoverResult call() { true, true, this.currentTopology, + this.availabilityMap, this.currentConnection, "TaskB"); } catch (final InterruptedException exception) { Thread.currentThread().interrupt(); - return new WriterFailoverResult(false, false, null, null, "TaskB"); + return new WriterFailoverResult(false, false, null, this.availabilityMap, null, "TaskB"); } catch (final Exception ex) { LOGGER.severe( () -> Messages.get( @@ -531,11 +540,10 @@ private boolean connectToWriter(final HostSpec writerCandidate) { try { // connect to the new writer this.currentConnection = this.connectionService.open(writerCandidate, this.props); - // TODO: replace with a map that is shared between the two handlers - this.pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.AVAILABLE); + this.availabilityMap.put(writerCandidate.getHost(), HostAvailability.AVAILABLE); return true; } catch (final SQLException exception) { - this.pluginService.setAvailability(writerCandidate.asAliases(), HostAvailability.NOT_AVAILABLE); + this.availabilityMap.put(writerCandidate.getHost(), HostAvailability.NOT_AVAILABLE); return false; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index f378d3c38..2a2eb057a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -782,6 +782,19 @@ protected void failoverWriter() throws SQLException { return; } + Map hostAvailabilityMap = failoverResult.getHostAvailabilityMap(); + if (hostAvailabilityMap != null && !hostAvailabilityMap.isEmpty()) { + List allHosts = this.pluginService.getAllHosts(); + for (HostSpec host : allHosts) { + for (String alias : host.getAliases()) { + HostAvailability availability = hostAvailabilityMap.get(alias); + if (availability != null) { + host.setAvailability(availability); + } + } + } + } + this.pluginService.setCurrentConnection(failoverResult.getNewConnection(), writerHostSpec); LOGGER.fine( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java index b1fc63a33..e79f395a7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java @@ -19,7 +19,9 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.List; +import java.util.Map; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.hostavailability.HostAvailability; /** * This class holds results of Writer Failover Process. @@ -29,6 +31,7 @@ public class WriterFailoverResult { private final boolean isConnected; private final boolean isNewHost; private final List topology; + private final Map hostAvailabilityMap; private final Connection newConnection; private final String taskName; private final SQLException exception; @@ -37,21 +40,24 @@ public WriterFailoverResult( final boolean isConnected, final boolean isNewHost, final List topology, + final Map hostAvailabilityMap, final Connection newConnection, final String taskName) { - this(isConnected, isNewHost, topology, newConnection, taskName, null); + this(isConnected, isNewHost, topology, hostAvailabilityMap, newConnection, taskName, null); } public WriterFailoverResult( final boolean isConnected, final boolean isNewHost, final List topology, + final Map hostAvailabilityMap, final Connection newConnection, final String taskName, final SQLException exception) { this.isConnected = isConnected; this.isNewHost = isNewHost; this.topology = topology; + this.hostAvailabilityMap = hostAvailabilityMap; this.newConnection = newConnection; this.taskName = taskName; this.exception = exception; @@ -86,6 +92,10 @@ public List getTopology() { return this.topology; } + public Map getHostAvailabilityMap() { + return this.hostAvailabilityMap; + } + /** * Get the new connection established by the failover procedure if successful. * From 0a38ab3a27ee2d7dd00c6983fd031e2d92c6d018 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 20 Aug 2025 12:45:16 -0700 Subject: [PATCH 08/54] Add hostAvailabilityMap to ReaderFailoverResult --- .../ClusterAwareReaderFailoverHandler.java | 102 ++++++++++++------ .../failover/FailoverConnectionPlugin.java | 31 +++--- .../plugin/failover/ReaderFailoverResult.java | 18 +++- 3 files changed, 100 insertions(+), 51 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 18df01f3c..1cfb9ad67 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -21,9 +21,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.concurrent.Callable; import java.util.concurrent.CompletionService; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; @@ -60,8 +62,6 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler Logger.getLogger(ClusterAwareReaderFailoverHandler.class.getName()); protected static final int DEFAULT_FAILOVER_TIMEOUT = 60000; // 60 sec protected static final int DEFAULT_READER_CONNECT_TIMEOUT = 30000; // 30 sec - public static final ReaderFailoverResult FAILED_READER_FAILOVER_RESULT = - new ReaderFailoverResult(null, null, false); protected Properties props; protected int maxFailoverTimeoutMs; protected int timeoutMs; @@ -135,26 +135,28 @@ public ReaderFailoverResult failover( throws SQLException { if (Utils.isNullOrEmpty(hosts)) { LOGGER.fine(() -> Messages.get("ClusterAwareReaderFailoverHandler.invalidTopology", new Object[] {"failover"})); - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, null); } + final Map availabilityMap = new ConcurrentHashMap<>(); final ExecutorService executor = ExecutorFactory.newSingleThreadExecutor("failover"); final Future future = - submitInternalFailoverTask(connectionService, hosts, currentHost, executor); - return getInternalFailoverResult(executor, future); + submitInternalFailoverTask(connectionService, hosts, currentHost, executor, availabilityMap); + return getInternalFailoverResult(executor, future, availabilityMap); } private Future submitInternalFailoverTask( final ConnectionService connectionService, final List hosts, final HostSpec currentHost, - final ExecutorService executor) { + final ExecutorService executor, + final Map availabilityMap) { final Future future = executor.submit(() -> { ReaderFailoverResult result; try { while (true) { - result = failoverInternal(connectionService, hosts, currentHost); + result = failoverInternal(connectionService, hosts, currentHost, availabilityMap); if (result != null && result.isConnected()) { return result; } @@ -162,9 +164,9 @@ private Future submitInternalFailoverTask( TimeUnit.SECONDS.sleep(1); } } catch (final SQLException ex) { - return new ReaderFailoverResult(null, null, false, ex); + return new ReaderFailoverResult(null, null, false, ex, availabilityMap); } catch (final Exception ex) { - return new ReaderFailoverResult(null, null, false, new SQLException(ex)); + return new ReaderFailoverResult(null, null, false, new SQLException(ex), availabilityMap); } }); executor.shutdown(); @@ -173,13 +175,14 @@ private Future submitInternalFailoverTask( private ReaderFailoverResult getInternalFailoverResult( final ExecutorService executor, - final Future future) throws SQLException { + final Future future, + final Map availabilityMap) throws SQLException { try { final ReaderFailoverResult result = future.get(this.maxFailoverTimeoutMs, TimeUnit.MILLISECONDS); if (result == null) { LOGGER.warning( Messages.get("ClusterAwareReaderFailoverHandler.timeout", new Object[] {this.maxFailoverTimeoutMs})); - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, availabilityMap); } return result; @@ -187,10 +190,10 @@ private ReaderFailoverResult getInternalFailoverResult( Thread.currentThread().interrupt(); throw new SQLException(Messages.get("ClusterAwareReaderFailoverHandler.interruptedThread"), "70100", e); } catch (final ExecutionException e) { - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, availabilityMap); } catch (final TimeoutException e) { future.cancel(true); - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, availabilityMap); } finally { if (!executor.isTerminated()) { executor.shutdownNow(); // terminate all remaining tasks @@ -201,13 +204,16 @@ private ReaderFailoverResult getInternalFailoverResult( protected ReaderFailoverResult failoverInternal( final ConnectionService connectionService, final List hosts, - final HostSpec currentHost) + final HostSpec currentHost, + final Map availabilityMap) throws SQLException { if (currentHost != null) { this.pluginService.setAvailability(currentHost.asAliases(), HostAvailability.NOT_AVAILABLE); + availabilityMap.put(currentHost.getHost(), HostAvailability.NOT_AVAILABLE); } + final List hostsByPriority = getHostsByPriority(hosts); - return getConnectionFromHostGroup(connectionService, hostsByPriority); + return getConnectionFromHostGroup(connectionService, hostsByPriority, availabilityMap); } public List getHostsByPriority(final List hosts) { @@ -257,11 +263,12 @@ public ReaderFailoverResult getReaderConnection( () -> Messages.get( "ClusterAwareReaderFailover.invalidTopology", new Object[] {"getReaderConnection"})); - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, null); } final List hostsByPriority = getReaderHostsByPriority(hostList); - return getConnectionFromHostGroup(connectionService, hostsByPriority); + final Map availabilityMap = new ConcurrentHashMap<>(); + return getConnectionFromHostGroup(connectionService, hostsByPriority, availabilityMap); } public List getReaderHostsByPriority(final List hosts) { @@ -303,7 +310,9 @@ public List getReaderHostsByPriority(final List hosts) { } private ReaderFailoverResult getConnectionFromHostGroup( - final ConnectionService connectionService, final List hosts) + final ConnectionService connectionService, + final List hosts, + final Map availabilityMap) throws SQLException { final ExecutorService executor = ExecutorFactory.newFixedThreadPool(2, "failover"); @@ -313,7 +322,7 @@ private ReaderFailoverResult getConnectionFromHostGroup( for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 final ReaderFailoverResult result = - getResultFromNextTaskBatch(connectionService, hosts, executor, completionService, i); + getResultFromNextTaskBatch(connectionService, hosts, availabilityMap, executor, completionService, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -329,7 +338,8 @@ private ReaderFailoverResult getConnectionFromHostGroup( return new ReaderFailoverResult( null, null, - false); + false, + availabilityMap); } finally { executor.shutdownNow(); } @@ -338,6 +348,7 @@ private ReaderFailoverResult getConnectionFromHostGroup( private ReaderFailoverResult getResultFromNextTaskBatch( final ConnectionService connectionService, final List hosts, + final Map availabilityMap, final ExecutorService executor, final CompletionService completionService, final int i) throws SQLException { @@ -346,12 +357,25 @@ private ReaderFailoverResult getResultFromNextTaskBatch( completionService.submit( // TODO: are there performance concerns with creating a new plugin service this often? new ConnectionAttemptTask( - connectionService, this.getNewPluginService(), hosts.get(i), this.props, this.isStrictReaderRequired)); + connectionService, + this.getNewPluginService(), + availabilityMap, + hosts.get(i), + this.props, + this.isStrictReaderRequired)); if (numTasks == 2) { - completionService.submit(new ConnectionAttemptTask(connectionService, this.getNewPluginService(), hosts.get(i + 1), this.props, this.isStrictReaderRequired)); + completionService.submit( + new ConnectionAttemptTask( + connectionService, + this.getNewPluginService(), + availabilityMap, + hosts.get(i + 1), + this.props, + this.isStrictReaderRequired)); } + for (int taskNum = 0; taskNum < numTasks; taskNum++) { - result = getNextResult(completionService); + result = getNextResult(completionService, availabilityMap); if (result.isConnected()) { executor.shutdownNow(); return result; @@ -364,20 +388,24 @@ private ReaderFailoverResult getResultFromNextTaskBatch( return new ReaderFailoverResult( null, null, - false); + false, + availabilityMap); } - private ReaderFailoverResult getNextResult(final CompletionService service) + private ReaderFailoverResult getNextResult( + final CompletionService service, + final Map availabilityMap) throws SQLException { + ReaderFailoverResult failureResult = new ReaderFailoverResult(null, null, false, null, availabilityMap); try { final Future future = service.poll(this.timeoutMs, TimeUnit.MILLISECONDS); if (future == null) { - return FAILED_READER_FAILOVER_RESULT; + return failureResult; } final ReaderFailoverResult result = future.get(); - return result == null ? FAILED_READER_FAILOVER_RESULT : result; + return result == null ? failureResult : result; } catch (final ExecutionException e) { - return FAILED_READER_FAILOVER_RESULT; + return failureResult; } catch (final InterruptedException e) { Thread.currentThread().interrupt(); // "Thread was interrupted" @@ -410,6 +438,7 @@ private PluginService getNewPluginService() { private static class ConnectionAttemptTask implements Callable { private final ConnectionService connectionService; private final PluginService pluginService; + private final Map availabilityMap; private final HostSpec newHost; private final Properties props; private final boolean isStrictReaderRequired; @@ -417,11 +446,13 @@ private static class ConnectionAttemptTask implements Callable availabilityMap, final HostSpec newHost, final Properties props, final boolean isStrictReaderRequired) { this.connectionService = connectionService; this.pluginService = pluginService; + this.availabilityMap = availabilityMap; this.newHost = newHost; this.props = props; this.isStrictReaderRequired = isStrictReaderRequired; @@ -442,7 +473,7 @@ public ReaderFailoverResult call() { copy.putAll(props); final Connection conn = this.connectionService.open(this.newHost, copy); - this.pluginService.setAvailability(this.newHost.asAliases(), HostAvailability.AVAILABLE); + this.availabilityMap.put(this.newHost.getHost(), HostAvailability.AVAILABLE); if (this.isStrictReaderRequired) { // need to ensure that new connection is a connection to a reader node @@ -460,7 +491,7 @@ public ReaderFailoverResult call() { // ignore } - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, availabilityMap); } } catch (SQLException e) { LOGGER.fine(Messages.get("ClusterAwareReaderFailoverHandler.errorGettingHostRole", new Object[] {e})); @@ -471,7 +502,7 @@ public ReaderFailoverResult call() { // ignore } - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, availabilityMap); } } @@ -480,9 +511,9 @@ public ReaderFailoverResult call() { "ClusterAwareReaderFailoverHandler.successfulReaderConnection", new Object[] {this.newHost.getUrl()})); LOGGER.fine("New reader failover connection object: " + conn); - return new ReaderFailoverResult(conn, this.newHost, true); + return new ReaderFailoverResult(conn, this.newHost, true, this.availabilityMap); } catch (final SQLException e) { - this.pluginService.setAvailability(newHost.asAliases(), HostAvailability.NOT_AVAILABLE); + this.availabilityMap.put(newHost.getHost(), HostAvailability.NOT_AVAILABLE); LOGGER.fine( () -> Messages.get( "ClusterAwareReaderFailoverHandler.failedReaderConnection", @@ -493,10 +524,11 @@ public ReaderFailoverResult call() { null, null, false, - e); + e, + this.availabilityMap); } - return FAILED_READER_FAILOVER_RESULT; + return new ReaderFailoverResult(null, null, false, null, availabilityMap); } } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 2a2eb057a..2a186272f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -672,6 +672,8 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException if (exception != null) { throw exception; } + + updateHostAvailability(result.getHostAvailabilityMap()); } if (result == null || !result.isConnected()) { @@ -757,6 +759,8 @@ protected void failoverWriter() throws SQLException { if (exception != null) { throw exception; } + + updateHostAvailability(failoverResult.getHostAvailabilityMap()); } if (failoverResult == null || !failoverResult.isConnected()) { @@ -782,19 +786,6 @@ protected void failoverWriter() throws SQLException { return; } - Map hostAvailabilityMap = failoverResult.getHostAvailabilityMap(); - if (hostAvailabilityMap != null && !hostAvailabilityMap.isEmpty()) { - List allHosts = this.pluginService.getAllHosts(); - for (HostSpec host : allHosts) { - for (String alias : host.getAliases()) { - HostAvailability availability = hostAvailabilityMap.get(alias); - if (availability != null) { - host.setAvailability(availability); - } - } - } - } - this.pluginService.setCurrentConnection(failoverResult.getNewConnection(), writerHostSpec); LOGGER.fine( @@ -835,6 +826,20 @@ protected void failoverWriter() throws SQLException { } } + private void updateHostAvailability(Map hostAvailabilityMap) { + if (hostAvailabilityMap != null && !hostAvailabilityMap.isEmpty()) { + List allHosts = this.pluginService.getAllHosts(); + for (HostSpec host : allHosts) { + for (String alias : host.getAliases()) { + HostAvailability availability = hostAvailabilityMap.get(alias); + if (availability != null) { + host.setAvailability(availability); + } + } + } + } + } + protected void invalidateCurrentConnection() { final Connection conn = this.pluginService.getCurrentConnection(); if (conn == null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java index 8d6d9630b..2ada37ca2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java @@ -18,7 +18,9 @@ import java.sql.Connection; import java.sql.SQLException; +import java.util.Map; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.hostavailability.HostAvailability; /** * This class holds results of Reader Failover Process. @@ -30,21 +32,27 @@ public class ReaderFailoverResult { private final boolean isConnected; private final SQLException exception; private final HostSpec newHost; + private final Map hostAvailabilityMap; public ReaderFailoverResult( - final Connection newConnection, final HostSpec newHost, final boolean isConnected) { - this(newConnection, newHost, isConnected, null); + final Connection newConnection, + final HostSpec newHost, + final boolean isConnected, + final Map hostAvailabilityMap) { + this(newConnection, newHost, isConnected, null, hostAvailabilityMap); } public ReaderFailoverResult( final Connection newConnection, final HostSpec newHost, final boolean isConnected, - final SQLException exception) { + final SQLException exception, + final Map hostAvailabilityMap) { this.newConnection = newConnection; this.newHost = newHost; this.isConnected = isConnected; this.exception = exception; + this.hostAvailabilityMap = hostAvailabilityMap; } /** @@ -82,4 +90,8 @@ public boolean isConnected() { public SQLException getException() { return exception; } + + public Map getHostAvailabilityMap() { + return hostAvailabilityMap; + } } From 1a7e2bcc4e3cba4bd5c63eb1103a93a9d1238f56 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 20 Aug 2025 13:33:10 -0700 Subject: [PATCH 09/54] Initialize HostListProvider in PartialPluginService constructor --- .../amazon/jdbc/PartialPluginService.java | 4 ++++ .../ClusterAwareReaderFailoverHandler.java | 22 ++++++++----------- .../ClusterAwareWriterFailoverHandler.java | 12 +++------- .../failover/FailoverConnectionPlugin.java | 9 +++----- 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 67027195b..c223cca57 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -36,6 +36,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.dialect.HostListProviderSupplier; import software.amazon.jdbc.exceptions.ExceptionHandler; import software.amazon.jdbc.exceptions.ExceptionManager; import software.amazon.jdbc.hostavailability.HostAvailability; @@ -123,6 +124,9 @@ public PartialPluginService( this.exceptionHandler = this.configurationProfile != null && this.configurationProfile.getExceptionHandler() != null ? this.configurationProfile.getExceptionHandler() : null; + + HostListProviderSupplier supplier = this.dbDialect.getHostListProvider(); + this.hostListProvider = supplier.getProvider(this.props, this.originalUrl, this.servicesContainer); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 1cfb9ad67..a1f917a2a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -19,6 +19,7 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -37,7 +38,6 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.HostListProviderSupplier; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; @@ -317,12 +317,15 @@ private ReaderFailoverResult getConnectionFromHostGroup( final ExecutorService executor = ExecutorFactory.newFixedThreadPool(2, "failover"); final CompletionService completionService = new ExecutorCompletionService<>(executor); + // The ConnectionAttemptTask threads should have their own plugin services since they execute concurrently and + // PluginService was not designed to be thread-safe. + List pluginServices = Arrays.asList(getNewPluginService(), getNewPluginService()); try { for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 final ReaderFailoverResult result = - getResultFromNextTaskBatch(connectionService, hosts, availabilityMap, executor, completionService, i); + getResultFromNextTaskBatch(connectionService, hosts, availabilityMap, executor, completionService, pluginServices, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -351,14 +354,14 @@ private ReaderFailoverResult getResultFromNextTaskBatch( final Map availabilityMap, final ExecutorService executor, final CompletionService completionService, + final List pluginServices, final int i) throws SQLException { ReaderFailoverResult result; final int numTasks = i + 1 < hosts.size() ? 2 : 1; completionService.submit( - // TODO: are there performance concerns with creating a new plugin service this often? new ConnectionAttemptTask( connectionService, - this.getNewPluginService(), + pluginServices.get(0), availabilityMap, hosts.get(i), this.props, @@ -367,7 +370,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( completionService.submit( new ConnectionAttemptTask( connectionService, - this.getNewPluginService(), + pluginServices.get(1), availabilityMap, hosts.get(i + 1), this.props, @@ -418,7 +421,7 @@ private ReaderFailoverResult getNextResult( } private PluginService getNewPluginService() { - PartialPluginService partialPluginService = new PartialPluginService( + return new PartialPluginService( this.servicesContainer, this.props, this.pluginService.getOriginalUrl(), @@ -426,13 +429,6 @@ private PluginService getNewPluginService() { this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect() ); - - // TODO: can we clean this up, eg move to PartialPluginService constructor? - final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); - partialPluginService.setHostListProvider( - supplier.getProvider(this.props, this.pluginService.getOriginalUrl(), this.servicesContainer)); - - return partialPluginService; } private static class ConnectionAttemptTask implements Callable { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 711cdb355..76703fbb7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -36,7 +36,6 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.HostListProviderSupplier; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; @@ -186,7 +185,9 @@ private void submitTasks( } private PluginService getNewPluginService() { - PartialPluginService partialPluginService = new PartialPluginService( + // Each task should get its own PluginService since they execute concurrently and PluginService was not designed to + // be thread-safe. + return new PartialPluginService( this.servicesContainer, this.initialConnectionProps, this.pluginService.getOriginalUrl(), @@ -194,13 +195,6 @@ private PluginService getNewPluginService() { this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect() ); - - // TODO: can we clean this up, eg move to PartialPluginService constructor? - final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); - partialPluginService.setHostListProvider( - supplier.getProvider(this.initialConnectionProps, this.pluginService.getOriginalUrl(), this.servicesContainer)); - - return partialPluginService; } private WriterFailoverResult getNextResult( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 2a186272f..7e365ad93 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -619,13 +619,10 @@ protected void dealWithIllegalStateException( protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); - - // TODO: After failover, retrieve the unavailable hosts from the handlers after replacing - // pluginService#setAvailability - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); if (this.connectionService == null) { + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); this.connectionService = new ConnectionServiceImpl( servicesContainer.getStorageService(), servicesContainer.getMonitorService(), From a965c3f96d4ed27570c694a09381839b49ab8fb8 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Thu, 21 Aug 2025 14:22:17 -0700 Subject: [PATCH 10/54] Pass ConnectionService as constructor arg to failover handlers --- .../amazon/jdbc/PartialPluginService.java | 4 ++ .../ClusterAwareReaderFailoverHandler.java | 29 +++++---- .../ClusterAwareWriterFailoverHandler.java | 18 +++--- .../failover/FailoverConnectionPlugin.java | 59 +++++++++---------- .../failover/ReaderFailoverHandler.java | 7 +-- .../failover/WriterFailoverHandler.java | 4 +- .../connection/ConnectionServiceImpl.java | 4 -- 7 files changed, 61 insertions(+), 64 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index c223cca57..880c6f0dc 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -125,6 +125,10 @@ public PartialPluginService( ? this.configurationProfile.getExceptionHandler() : null; + servicesContainer.setHostListProviderService(this); + servicesContainer.setPluginService(this); + servicesContainer.setPluginManagerService(this); + HostListProviderSupplier supplier = this.dbDialect.getHostListProvider(); this.hostListProvider = supplier.getProvider(this.props, this.originalUrl, this.servicesContainer); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index a1f917a2a..18101b26d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -67,6 +67,7 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler protected int timeoutMs; protected boolean isStrictReaderRequired; protected final FullServicesContainer servicesContainer; + protected final ConnectionService connectionService; protected final PluginService pluginService; /** @@ -77,9 +78,11 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler */ public ClusterAwareReaderFailoverHandler( final FullServicesContainer servicesContainer, + final ConnectionService connectionService, final Properties props) { this( servicesContainer, + connectionService, props, DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, @@ -98,11 +101,13 @@ public ClusterAwareReaderFailoverHandler( */ public ClusterAwareReaderFailoverHandler( final FullServicesContainer servicesContainer, + final ConnectionService connectionService, final Properties props, final int maxFailoverTimeoutMs, final int timeoutMs, final boolean isStrictReaderRequired) { this.servicesContainer = servicesContainer; + this.connectionService = connectionService; this.pluginService = servicesContainer.getPluginService(); this.props = props; this.maxFailoverTimeoutMs = maxFailoverTimeoutMs; @@ -130,8 +135,7 @@ protected void setTimeoutMs(final int timeoutMs) { * @return {@link ReaderFailoverResult} The results of this process. */ @Override - public ReaderFailoverResult failover( - final ConnectionService connectionService, final List hosts, final HostSpec currentHost) + public ReaderFailoverResult failover(final List hosts, final HostSpec currentHost) throws SQLException { if (Utils.isNullOrEmpty(hosts)) { LOGGER.fine(() -> Messages.get("ClusterAwareReaderFailoverHandler.invalidTopology", new Object[] {"failover"})); @@ -142,12 +146,11 @@ public ReaderFailoverResult failover( final ExecutorService executor = ExecutorFactory.newSingleThreadExecutor("failover"); final Future future = - submitInternalFailoverTask(connectionService, hosts, currentHost, executor, availabilityMap); + submitInternalFailoverTask(hosts, currentHost, executor, availabilityMap); return getInternalFailoverResult(executor, future, availabilityMap); } private Future submitInternalFailoverTask( - final ConnectionService connectionService, final List hosts, final HostSpec currentHost, final ExecutorService executor, @@ -156,7 +159,7 @@ private Future submitInternalFailoverTask( ReaderFailoverResult result; try { while (true) { - result = failoverInternal(connectionService, hosts, currentHost, availabilityMap); + result = failoverInternal(hosts, currentHost, availabilityMap); if (result != null && result.isConnected()) { return result; } @@ -202,7 +205,6 @@ private ReaderFailoverResult getInternalFailoverResult( } protected ReaderFailoverResult failoverInternal( - final ConnectionService connectionService, final List hosts, final HostSpec currentHost, final Map availabilityMap) @@ -213,7 +215,7 @@ protected ReaderFailoverResult failoverInternal( } final List hostsByPriority = getHostsByPriority(hosts); - return getConnectionFromHostGroup(connectionService, hostsByPriority, availabilityMap); + return getConnectionFromHostGroup(hostsByPriority, availabilityMap); } public List getHostsByPriority(final List hosts) { @@ -255,8 +257,7 @@ public List getHostsByPriority(final List hosts) { * @return {@link ReaderFailoverResult} The results of this process. */ @Override - public ReaderFailoverResult getReaderConnection( - final ConnectionService connectionService, final List hostList) + public ReaderFailoverResult getReaderConnection(final List hostList) throws SQLException { if (Utils.isNullOrEmpty(hostList)) { LOGGER.fine( @@ -268,7 +269,7 @@ public ReaderFailoverResult getReaderConnection( final List hostsByPriority = getReaderHostsByPriority(hostList); final Map availabilityMap = new ConcurrentHashMap<>(); - return getConnectionFromHostGroup(connectionService, hostsByPriority, availabilityMap); + return getConnectionFromHostGroup(hostsByPriority, availabilityMap); } public List getReaderHostsByPriority(final List hosts) { @@ -310,7 +311,6 @@ public List getReaderHostsByPriority(final List hosts) { } private ReaderFailoverResult getConnectionFromHostGroup( - final ConnectionService connectionService, final List hosts, final Map availabilityMap) throws SQLException { @@ -325,7 +325,7 @@ private ReaderFailoverResult getConnectionFromHostGroup( for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 final ReaderFailoverResult result = - getResultFromNextTaskBatch(connectionService, hosts, availabilityMap, executor, completionService, pluginServices, i); + getResultFromNextTaskBatch(hosts, availabilityMap, executor, completionService, pluginServices, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -349,7 +349,6 @@ private ReaderFailoverResult getConnectionFromHostGroup( } private ReaderFailoverResult getResultFromNextTaskBatch( - final ConnectionService connectionService, final List hosts, final Map availabilityMap, final ExecutorService executor, @@ -360,7 +359,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( final int numTasks = i + 1 < hosts.size() ? 2 : 1; completionService.submit( new ConnectionAttemptTask( - connectionService, + this.connectionService, pluginServices.get(0), availabilityMap, hosts.get(i), @@ -369,7 +368,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( if (numTasks == 2) { completionService.submit( new ConnectionAttemptTask( - connectionService, + this.connectionService, pluginServices.get(1), availabilityMap, hosts.get(i + 1), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 76703fbb7..07894bf63 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -60,14 +60,17 @@ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler protected int reconnectWriterIntervalMs = 5000; // 5 sec protected Properties initialConnectionProps; protected FullServicesContainer servicesContainer; + protected ConnectionService connectionService; protected PluginService pluginService; protected ReaderFailoverHandler readerFailoverHandler; public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, + final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps) { this.servicesContainer = servicesContainer; + this.connectionService = connectionService; this.pluginService = servicesContainer.getPluginService(); this.readerFailoverHandler = readerFailoverHandler; this.initialConnectionProps = initialConnectionProps; @@ -75,6 +78,7 @@ public ClusterAwareWriterFailoverHandler( public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, + final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps, final int failoverTimeoutMs, @@ -82,6 +86,7 @@ public ClusterAwareWriterFailoverHandler( final int reconnectWriterIntervalMs) { this( servicesContainer, + connectionService, readerFailoverHandler, initialConnectionProps); this.maxFailoverTimeoutMs = failoverTimeoutMs; @@ -96,7 +101,7 @@ public ClusterAwareWriterFailoverHandler( * @return {@link WriterFailoverResult} The results of this process. */ @Override - public WriterFailoverResult failover(final ConnectionService connectionService, final List currentTopology) + public WriterFailoverResult failover(final List currentTopology) throws SQLException { if (Utils.isNullOrEmpty(currentTopology)) { LOGGER.severe(() -> Messages.get("ClusterAwareWriterFailoverHandler.failoverCalledWithInvalidTopology")); @@ -109,7 +114,7 @@ public WriterFailoverResult failover(final ConnectionService connectionService, final ExecutorService executorService = ExecutorFactory.newFixedThreadPool(2, "failover"); final CompletionService completionService = new ExecutorCompletionService<>(executorService); - submitTasks(connectionService, currentTopology, executorService, completionService, singleTask); + submitTasks(currentTopology, executorService, completionService, singleTask); try { final long startTimeNano = System.nanoTime(); @@ -152,7 +157,6 @@ private static HostSpec getWriter(final List topology) { } private void submitTasks( - final ConnectionService connectionService, final List currentTopology, final ExecutorService executorService, final CompletionService completionService, @@ -162,7 +166,7 @@ private void submitTasks( if (!singleTask) { completionService.submit( new ReconnectToWriterHandler( - connectionService, + this.connectionService, this.getNewPluginService(), availabilityMap, writerHost, @@ -172,7 +176,7 @@ private void submitTasks( completionService.submit( new WaitForNewWriterHandler( - connectionService, + this.connectionService, this.getNewPluginService(), availabilityMap, this.readerFailoverHandler, @@ -310,7 +314,7 @@ public WriterFailoverResult call() { conn.close(); } - conn = connectionService.open(this.originalWriterHost, this.props); + conn = this.connectionService.open(this.originalWriterHost, this.props); this.pluginService.forceRefreshHostList(conn); latestTopology = this.pluginService.getAllHosts(); } catch (final SQLException exception) { @@ -439,7 +443,7 @@ private void connectToReader() throws InterruptedException { while (true) { try { final ReaderFailoverResult connResult = - readerFailoverHandler.getReaderConnection(this.connectionService, this.currentTopology); + readerFailoverHandler.getReaderConnection(this.currentTopology); if (isValidReaderConnection(connResult)) { this.currentReaderConnection = connResult.getConnection(); this.currentReaderHost = connResult.getHost(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 7e365ad93..a423a27c9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -28,7 +28,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; @@ -122,8 +122,8 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { private RdsUrlType rdsUrlType = null; private HostListProviderService hostListProviderService; private final AuroraStaleDnsHelper staleDnsHelper; - private Supplier writerFailoverHandlerSupplier; - private Supplier readerFailoverHandlerSupplier; + private Function writerFailoverHandlerSupplier; + private Function readerFailoverHandlerSupplier; public static final AwsWrapperProperty FAILOVER_CLUSTER_TOPOLOGY_REFRESH_RATE_MS = new AwsWrapperProperty( @@ -316,16 +316,18 @@ public void initHostProvider( initHostProvider( hostListProviderService, initHostProviderFunc, - () -> + (connectionService) -> new ClusterAwareReaderFailoverHandler( this.servicesContainer, + connectionService, this.properties, this.failoverTimeoutMsSetting, this.failoverReaderConnectTimeoutMsSetting, this.failoverMode == FailoverMode.STRICT_READER), - () -> + (connectionService) -> new ClusterAwareWriterFailoverHandler( this.servicesContainer, + connectionService, this.readerFailoverHandler, this.properties, this.failoverTimeoutMsSetting, @@ -336,8 +338,8 @@ public void initHostProvider( void initHostProvider( final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc, - final Supplier readerFailoverHandlerSupplier, - final Supplier writerFailoverHandlerSupplier) + final Function readerFailoverHandlerSupplier, + final Function writerFailoverHandlerSupplier) throws SQLException { this.readerFailoverHandlerSupplier = readerFailoverHandlerSupplier; this.writerFailoverHandlerSupplier = writerFailoverHandlerSupplier; @@ -618,24 +620,6 @@ protected void dealWithIllegalStateException( */ protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); - - if (this.connectionService == null) { - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); - this.connectionService = new ConnectionServiceImpl( - servicesContainer.getStorageService(), - servicesContainer.getMonitorService(), - servicesContainer.getTelemetryFactory(), - defaultConnectionProvider, - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - properties - ); - } - if (this.failoverMode == FailoverMode.STRICT_WRITER) { failoverWriter(); } else { @@ -662,8 +646,7 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException failedHost = failedHostSpec; } - final ReaderFailoverResult result = readerFailoverHandler.failover( - this.connectionService, this.pluginService.getHosts(), failedHost); + final ReaderFailoverResult result = readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); if (result != null) { final SQLException exception = result.getException(); if (exception != null) { @@ -750,7 +733,7 @@ protected void failoverWriter() throws SQLException { try { LOGGER.info(() -> Messages.get("Failover.startWriterFailover")); final WriterFailoverResult failoverResult = - this.writerFailoverHandler.failover(this.connectionService, this.pluginService.getAllHosts()); + this.writerFailoverHandler.failover(this.pluginService.getAllHosts()); if (failoverResult != null) { final SQLException exception = failoverResult.getException(); if (exception != null) { @@ -925,20 +908,36 @@ public Connection connect( final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { + if (this.connectionService == null) { + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + this.connectionService = new ConnectionServiceImpl( + servicesContainer.getStorageService(), + servicesContainer.getMonitorService(), + servicesContainer.getTelemetryFactory(), + defaultConnectionProvider, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect(), + properties + ); + } this.initFailoverMode(); if (this.readerFailoverHandler == null) { if (this.readerFailoverHandlerSupplier == null) { throw new SQLException(Messages.get("Failover.nullReaderFailoverHandlerSupplier")); } - this.readerFailoverHandler = this.readerFailoverHandlerSupplier.get(); + this.readerFailoverHandler = this.readerFailoverHandlerSupplier.apply(this.connectionService); } if (this.writerFailoverHandler == null) { if (this.writerFailoverHandlerSupplier == null) { throw new SQLException(Messages.get("Failover.nullWriterFailoverHandlerSupplier")); } - this.writerFailoverHandler = this.writerFailoverHandlerSupplier.get(); + this.writerFailoverHandler = this.writerFailoverHandlerSupplier.apply(this.connectionService); } Connection conn = null; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java index 008c1e17e..e006558b6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java @@ -19,7 +19,6 @@ import java.sql.SQLException; import java.util.List; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.util.connection.ConnectionService; /** * Interface for Reader Failover Process handler. This handler implements all necessary logic to try @@ -37,8 +36,7 @@ public interface ReaderFailoverHandler { * @return {@link ReaderFailoverResult} The results of this process. * @throws SQLException indicating whether the failover attempt was successful. */ - ReaderFailoverResult failover( - ConnectionService connectionService, List hosts, HostSpec currentHost) throws SQLException; + ReaderFailoverResult failover(List hosts, HostSpec currentHost) throws SQLException; /** * Called to get any available reader connection. If no reader is available then result of process @@ -48,6 +46,5 @@ ReaderFailoverResult failover( * @return {@link ReaderFailoverResult} The results of this process. * @throws SQLException if any error occurred while attempting a reader connection. */ - ReaderFailoverResult getReaderConnection( - ConnectionService connectionService, List hostList) throws SQLException; + ReaderFailoverResult getReaderConnection(List hostList) throws SQLException; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java index 68e69259b..caae8de8c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java @@ -19,7 +19,6 @@ import java.sql.SQLException; import java.util.List; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.util.connection.ConnectionService; /** * Interface for Writer Failover Process handler. This handler implements all necessary logic to try @@ -34,6 +33,5 @@ public interface WriterFailoverHandler { * @return {@link WriterFailoverResult} The results of this process. * @throws SQLException indicating whether the failover attempt was successful. */ - WriterFailoverResult failover( - ConnectionService connectionService, List currentTopology) throws SQLException; + WriterFailoverResult failover(List currentTopology) throws SQLException; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 73a68fdcb..0e506a112 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -68,10 +68,6 @@ public ConnectionServiceImpl( ); this.pluginService = partialPluginService; - servicesContainer.setHostListProviderService(partialPluginService); - servicesContainer.setPluginService(partialPluginService); - servicesContainer.setPluginManagerService(partialPluginService); - this.pluginManager.init(servicesContainer, props, partialPluginService, null); } From a3d8a5e1cdd51d5aadf77e09ed829490d787f9a1 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 10:39:13 -0700 Subject: [PATCH 11/54] PR suggestions --- .../ClusterAwareReaderFailoverHandler.java | 120 +++++++----------- .../ClusterAwareWriterFailoverHandler.java | 52 ++++---- .../failover/FailoverConnectionPlugin.java | 87 +++++++------ .../failover/ReaderFailoverHandler.java | 10 ++ .../plugin/failover/ReaderFailoverResult.java | 16 +-- .../failover/WriterFailoverHandler.java | 10 ++ .../plugin/failover/WriterFailoverResult.java | 12 +- 7 files changed, 140 insertions(+), 167 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 18101b26d..4cb587f55 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -60,15 +60,18 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler private static final Logger LOGGER = Logger.getLogger(ClusterAwareReaderFailoverHandler.class.getName()); + protected static final ReaderFailoverResult FAILED_READER_FAILOVER_RESULT = + new ReaderFailoverResult(null, null, false); protected static final int DEFAULT_FAILOVER_TIMEOUT = 60000; // 60 sec protected static final int DEFAULT_READER_CONNECT_TIMEOUT = 30000; // 30 sec + protected final Map hostAvailabilityMap = new ConcurrentHashMap<>(); + protected final FullServicesContainer servicesContainer; + protected final ConnectionService connectionService; + protected final PluginService pluginService; protected Properties props; protected int maxFailoverTimeoutMs; protected int timeoutMs; protected boolean isStrictReaderRequired; - protected final FullServicesContainer servicesContainer; - protected final ConnectionService connectionService; - protected final PluginService pluginService; /** * ClusterAwareReaderFailoverHandler constructor. @@ -115,6 +118,11 @@ public ClusterAwareReaderFailoverHandler( this.isStrictReaderRequired = isStrictReaderRequired; } + @Override + public Map getHostAvailabilityMap() { + return this.hostAvailabilityMap; + } + /** * Set process timeout in millis. Entire process of connecting to a reader will be limited by this * time duration. @@ -125,41 +133,29 @@ protected void setTimeoutMs(final int timeoutMs) { this.timeoutMs = timeoutMs; } - /** - * Called to start Reader Failover Process. This process tries to connect to any reader. If no - * reader is available then driver may also try to connect to a writer host, down hosts, and the - * current reader host. - * - * @param hosts Cluster current topology - * @param currentHost The currently connected host that has failed. - * @return {@link ReaderFailoverResult} The results of this process. - */ @Override public ReaderFailoverResult failover(final List hosts, final HostSpec currentHost) throws SQLException { if (Utils.isNullOrEmpty(hosts)) { LOGGER.fine(() -> Messages.get("ClusterAwareReaderFailoverHandler.invalidTopology", new Object[] {"failover"})); - return new ReaderFailoverResult(null, null, false, null, null); + return FAILED_READER_FAILOVER_RESULT; } - final Map availabilityMap = new ConcurrentHashMap<>(); final ExecutorService executor = ExecutorFactory.newSingleThreadExecutor("failover"); - final Future future = - submitInternalFailoverTask(hosts, currentHost, executor, availabilityMap); - return getInternalFailoverResult(executor, future, availabilityMap); + final Future future = submitInternalFailoverTask(hosts, currentHost, executor); + return getInternalFailoverResult(executor, future); } private Future submitInternalFailoverTask( final List hosts, final HostSpec currentHost, - final ExecutorService executor, - final Map availabilityMap) { + final ExecutorService executor) { final Future future = executor.submit(() -> { ReaderFailoverResult result; try { while (true) { - result = failoverInternal(hosts, currentHost, availabilityMap); + result = failoverInternal(hosts, currentHost); if (result != null && result.isConnected()) { return result; } @@ -167,9 +163,9 @@ private Future submitInternalFailoverTask( TimeUnit.SECONDS.sleep(1); } } catch (final SQLException ex) { - return new ReaderFailoverResult(null, null, false, ex, availabilityMap); + return new ReaderFailoverResult(null, null, false, ex); } catch (final Exception ex) { - return new ReaderFailoverResult(null, null, false, new SQLException(ex), availabilityMap); + return new ReaderFailoverResult(null, null, false, new SQLException(ex)); } }); executor.shutdown(); @@ -178,14 +174,13 @@ private Future submitInternalFailoverTask( private ReaderFailoverResult getInternalFailoverResult( final ExecutorService executor, - final Future future, - final Map availabilityMap) throws SQLException { + final Future future) throws SQLException { try { final ReaderFailoverResult result = future.get(this.maxFailoverTimeoutMs, TimeUnit.MILLISECONDS); if (result == null) { LOGGER.warning( Messages.get("ClusterAwareReaderFailoverHandler.timeout", new Object[] {this.maxFailoverTimeoutMs})); - return new ReaderFailoverResult(null, null, false, null, availabilityMap); + return FAILED_READER_FAILOVER_RESULT; } return result; @@ -193,10 +188,10 @@ private ReaderFailoverResult getInternalFailoverResult( Thread.currentThread().interrupt(); throw new SQLException(Messages.get("ClusterAwareReaderFailoverHandler.interruptedThread"), "70100", e); } catch (final ExecutionException e) { - return new ReaderFailoverResult(null, null, false, null, availabilityMap); + return FAILED_READER_FAILOVER_RESULT; } catch (final TimeoutException e) { future.cancel(true); - return new ReaderFailoverResult(null, null, false, null, availabilityMap); + return FAILED_READER_FAILOVER_RESULT; } finally { if (!executor.isTerminated()) { executor.shutdownNow(); // terminate all remaining tasks @@ -204,18 +199,15 @@ private ReaderFailoverResult getInternalFailoverResult( } } - protected ReaderFailoverResult failoverInternal( - final List hosts, - final HostSpec currentHost, - final Map availabilityMap) + protected ReaderFailoverResult failoverInternal(final List hosts, final HostSpec currentHost) throws SQLException { if (currentHost != null) { this.pluginService.setAvailability(currentHost.asAliases(), HostAvailability.NOT_AVAILABLE); - availabilityMap.put(currentHost.getHost(), HostAvailability.NOT_AVAILABLE); + this.hostAvailabilityMap.put(currentHost.getHost(), HostAvailability.NOT_AVAILABLE); } final List hostsByPriority = getHostsByPriority(hosts); - return getConnectionFromHostGroup(hostsByPriority, availabilityMap); + return getConnectionFromHostGroup(hostsByPriority); } public List getHostsByPriority(final List hosts) { @@ -264,12 +256,11 @@ public ReaderFailoverResult getReaderConnection(final List hostList) () -> Messages.get( "ClusterAwareReaderFailover.invalidTopology", new Object[] {"getReaderConnection"})); - return new ReaderFailoverResult(null, null, false, null, null); + return FAILED_READER_FAILOVER_RESULT; } final List hostsByPriority = getReaderHostsByPriority(hostList); - final Map availabilityMap = new ConcurrentHashMap<>(); - return getConnectionFromHostGroup(hostsByPriority, availabilityMap); + return getConnectionFromHostGroup(hostsByPriority); } public List getReaderHostsByPriority(final List hosts) { @@ -291,11 +282,10 @@ public List getReaderHostsByPriority(final List hosts) { Collections.shuffle(activeReaders); Collections.shuffle(downHostList); - final List hostsByPriority = new ArrayList<>(); - hostsByPriority.addAll(activeReaders); + final List hostsByPriority = new ArrayList<>(activeReaders); + final int numOfReaders = activeReaders.size() + downHostList.size(); hostsByPriority.addAll(downHostList); - final int numOfReaders = activeReaders.size() + downHostList.size(); if (writerHost == null) { return hostsByPriority; } @@ -310,10 +300,7 @@ public List getReaderHostsByPriority(final List hosts) { return hostsByPriority; } - private ReaderFailoverResult getConnectionFromHostGroup( - final List hosts, - final Map availabilityMap) - throws SQLException { + private ReaderFailoverResult getConnectionFromHostGroup(final List hosts) throws SQLException { final ExecutorService executor = ExecutorFactory.newFixedThreadPool(2, "failover"); final CompletionService completionService = new ExecutorCompletionService<>(executor); @@ -325,7 +312,7 @@ private ReaderFailoverResult getConnectionFromHostGroup( for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 final ReaderFailoverResult result = - getResultFromNextTaskBatch(hosts, availabilityMap, executor, completionService, pluginServices, i); + getResultFromNextTaskBatch(hosts, executor, completionService, pluginServices, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -338,11 +325,7 @@ private ReaderFailoverResult getConnectionFromHostGroup( } } - return new ReaderFailoverResult( - null, - null, - false, - availabilityMap); + return new ReaderFailoverResult(null, null, false); } finally { executor.shutdownNow(); } @@ -350,7 +333,6 @@ private ReaderFailoverResult getConnectionFromHostGroup( private ReaderFailoverResult getResultFromNextTaskBatch( final List hosts, - final Map availabilityMap, final ExecutorService executor, final CompletionService completionService, final List pluginServices, @@ -361,7 +343,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( new ConnectionAttemptTask( this.connectionService, pluginServices.get(0), - availabilityMap, + this.hostAvailabilityMap, hosts.get(i), this.props, this.isStrictReaderRequired)); @@ -370,14 +352,14 @@ private ReaderFailoverResult getResultFromNextTaskBatch( new ConnectionAttemptTask( this.connectionService, pluginServices.get(1), - availabilityMap, + this.hostAvailabilityMap, hosts.get(i + 1), this.props, this.isStrictReaderRequired)); } for (int taskNum = 0; taskNum < numTasks; taskNum++) { - result = getNextResult(completionService, availabilityMap); + result = getNextResult(completionService); if (result.isConnected()) { executor.shutdownNow(); return result; @@ -387,27 +369,20 @@ private ReaderFailoverResult getResultFromNextTaskBatch( return result; } } - return new ReaderFailoverResult( - null, - null, - false, - availabilityMap); + return new ReaderFailoverResult(null, null, false); } - private ReaderFailoverResult getNextResult( - final CompletionService service, - final Map availabilityMap) + private ReaderFailoverResult getNextResult(final CompletionService service) throws SQLException { - ReaderFailoverResult failureResult = new ReaderFailoverResult(null, null, false, null, availabilityMap); try { final Future future = service.poll(this.timeoutMs, TimeUnit.MILLISECONDS); if (future == null) { - return failureResult; + return FAILED_READER_FAILOVER_RESULT; } final ReaderFailoverResult result = future.get(); - return result == null ? failureResult : result; + return result == null ? FAILED_READER_FAILOVER_RESULT : result; } catch (final ExecutionException e) { - return failureResult; + return FAILED_READER_FAILOVER_RESULT; } catch (final InterruptedException e) { Thread.currentThread().interrupt(); // "Thread was interrupted" @@ -486,7 +461,7 @@ public ReaderFailoverResult call() { // ignore } - return new ReaderFailoverResult(null, null, false, null, availabilityMap); + return FAILED_READER_FAILOVER_RESULT; } } catch (SQLException e) { LOGGER.fine(Messages.get("ClusterAwareReaderFailoverHandler.errorGettingHostRole", new Object[] {e})); @@ -497,7 +472,7 @@ public ReaderFailoverResult call() { // ignore } - return new ReaderFailoverResult(null, null, false, null, availabilityMap); + return FAILED_READER_FAILOVER_RESULT; } } @@ -506,7 +481,7 @@ public ReaderFailoverResult call() { "ClusterAwareReaderFailoverHandler.successfulReaderConnection", new Object[] {this.newHost.getUrl()})); LOGGER.fine("New reader failover connection object: " + conn); - return new ReaderFailoverResult(conn, this.newHost, true, this.availabilityMap); + return new ReaderFailoverResult(conn, this.newHost, true); } catch (final SQLException e) { this.availabilityMap.put(newHost.getHost(), HostAvailability.NOT_AVAILABLE); LOGGER.fine( @@ -515,15 +490,10 @@ public ReaderFailoverResult call() { new Object[] {this.newHost.getUrl()})); // Propagate exceptions that are not caused by network errors. if (!this.pluginService.isNetworkException(e, this.pluginService.getTargetDriverDialect())) { - return new ReaderFailoverResult( - null, - null, - false, - e, - this.availabilityMap); + return new ReaderFailoverResult(null, null, false, e); } - return new ReaderFailoverResult(null, null, false, null, availabilityMap); + return FAILED_READER_FAILOVER_RESULT; } } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 07894bf63..8533b1055 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -54,15 +54,18 @@ */ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler { private static final Logger LOGGER = Logger.getLogger(ClusterAwareReaderFailoverHandler.class.getName()); - + protected static final WriterFailoverResult DEFAULT_RESULT = + new WriterFailoverResult(false, false, null, null, "None"); + + protected final Properties initialConnectionProps; + protected final FullServicesContainer servicesContainer; + protected final ConnectionService connectionService; + protected final PluginService pluginService; + protected final ReaderFailoverHandler readerFailoverHandler; + protected final Map hostAvailabilityMap = new ConcurrentHashMap<>(); protected int maxFailoverTimeoutMs = 60000; // 60 sec protected int readTopologyIntervalMs = 5000; // 5 sec protected int reconnectWriterIntervalMs = 5000; // 5 sec - protected Properties initialConnectionProps; - protected FullServicesContainer servicesContainer; - protected ConnectionService connectionService; - protected PluginService pluginService; - protected ReaderFailoverHandler readerFailoverHandler; public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, @@ -94,18 +97,17 @@ public ClusterAwareWriterFailoverHandler( this.reconnectWriterIntervalMs = reconnectWriterIntervalMs; } - /** - * Called to start Writer Failover Process. - * - * @param currentTopology Cluster current topology - * @return {@link WriterFailoverResult} The results of this process. - */ + @Override + public Map getHostAvailabilityMap() { + return this.hostAvailabilityMap; + } + @Override public WriterFailoverResult failover(final List currentTopology) throws SQLException { if (Utils.isNullOrEmpty(currentTopology)) { LOGGER.severe(() -> Messages.get("ClusterAwareWriterFailoverHandler.failoverCalledWithInvalidTopology")); - return new WriterFailoverResult(false, false, null, null, null, "None"); + return DEFAULT_RESULT; } final boolean singleTask = @@ -135,7 +137,7 @@ public WriterFailoverResult failover(final List currentTopology) } LOGGER.fine(() -> Messages.get("ClusterAwareWriterFailoverHandler.failedToConnectToWriterInstance")); - return new WriterFailoverResult(false, false, null, result.getHostAvailabilityMap(), null, "None"); + return DEFAULT_RESULT; } finally { if (!executorService.isTerminated()) { executorService.shutdownNow(); // terminate all remaining tasks @@ -162,13 +164,12 @@ private void submitTasks( final CompletionService completionService, final boolean singleTask) { final HostSpec writerHost = getWriter(currentTopology); - final Map availabilityMap = new ConcurrentHashMap<>(); if (!singleTask) { completionService.submit( new ReconnectToWriterHandler( this.connectionService, this.getNewPluginService(), - availabilityMap, + this.hostAvailabilityMap, writerHost, this.initialConnectionProps, this.reconnectWriterIntervalMs)); @@ -178,7 +179,7 @@ private void submitTasks( new WaitForNewWriterHandler( this.connectionService, this.getNewPluginService(), - availabilityMap, + this.hostAvailabilityMap, this.readerFailoverHandler, writerHost, this.initialConnectionProps, @@ -210,7 +211,7 @@ private WriterFailoverResult getNextResult( timeoutMs, TimeUnit.MILLISECONDS); if (firstCompleted == null) { // The task was unsuccessful and we have timed out - return new WriterFailoverResult(false, false, null, null, null, "None"); + return DEFAULT_RESULT; } final WriterFailoverResult result = firstCompleted.get(); if (result.isConnected()) { @@ -229,7 +230,7 @@ private WriterFailoverResult getNextResult( } catch (final ExecutionException e) { // return failure below } - return new WriterFailoverResult(false, false, null, null, null, "None"); + return DEFAULT_RESULT; } private void logTaskSuccess(final WriterFailoverResult result) { @@ -324,7 +325,7 @@ public WriterFailoverResult call() { () -> Messages.get( "ClusterAwareWriterFailoverHandler.taskAEncounteredException", new Object[] {exception})); - return new WriterFailoverResult(false, false, null, this.availabilityMap, null, "TaskA", exception); + return new WriterFailoverResult(false, false, null, null, "TaskA", exception); } } @@ -335,14 +336,14 @@ public WriterFailoverResult call() { success = isCurrentHostWriter(latestTopology); LOGGER.finest("[TaskA] success: " + success); - availabilityMap.put(this.originalWriterHost.getHost(), HostAvailability.AVAILABLE); - return new WriterFailoverResult(success, false, latestTopology, this.availabilityMap, success ? conn : null, "TaskA"); + this.availabilityMap.put(this.originalWriterHost.getHost(), HostAvailability.AVAILABLE); + return new WriterFailoverResult(success, false, latestTopology, success ? conn : null, "TaskA"); } catch (final InterruptedException exception) { Thread.currentThread().interrupt(); - return new WriterFailoverResult(success, false, latestTopology, this.availabilityMap, success ? conn : null, "TaskA"); + return new WriterFailoverResult(success, false, latestTopology, success ? conn : null, "TaskA"); } catch (final Exception ex) { LOGGER.severe(ex::getMessage); - return new WriterFailoverResult(false, false, null, this.availabilityMap, null, "TaskA"); + return new WriterFailoverResult(false, false, null, null, "TaskA"); } finally { try { if (conn != null && !success && !conn.isClosed()) { @@ -421,12 +422,11 @@ public WriterFailoverResult call() { true, true, this.currentTopology, - this.availabilityMap, this.currentConnection, "TaskB"); } catch (final InterruptedException exception) { Thread.currentThread().interrupt(); - return new WriterFailoverResult(false, false, null, this.availabilityMap, null, "TaskB"); + return new WriterFailoverResult(false, false, null, null, "TaskB"); } catch (final Exception ex) { LOGGER.severe( () -> Messages.get( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index a423a27c9..4b07f0996 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -620,6 +620,23 @@ protected void dealWithIllegalStateException( */ protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); + if (this.connectionService == null) { + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + this.connectionService = new ConnectionServiceImpl( + servicesContainer.getStorageService(), + servicesContainer.getMonitorService(), + servicesContainer.getTelemetryFactory(), + defaultConnectionProvider, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect(), + properties + ); + } + if (this.failoverMode == FailoverMode.STRICT_WRITER) { failoverWriter(); } else { @@ -635,6 +652,13 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException this.failoverReaderTriggeredCounter.inc(); } + if (this.readerFailoverHandler == null) { + if (this.readerFailoverHandlerSupplier == null) { + throw new SQLException(Messages.get("Failover.nullReaderFailoverHandlerSupplier")); + } + this.readerFailoverHandler = this.readerFailoverHandlerSupplier.apply(this.connectionService); + } + final long failoverStartNano = System.nanoTime(); try { @@ -646,14 +670,14 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException failedHost = failedHostSpec; } - final ReaderFailoverResult result = readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); + final ReaderFailoverResult result = this.readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); if (result != null) { final SQLException exception = result.getException(); if (exception != null) { throw exception; } - updateHostAvailability(result.getHostAvailabilityMap()); + updateHostAvailability(this.readerFailoverHandler.getHostAvailabilityMap()); } if (result == null || !result.isConnected()) { @@ -693,6 +717,7 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException LOGGER.finest(() -> Messages.get( "Failover.readerFailoverElapsed", new Object[]{TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - failoverStartNano)})); + this.readerFailoverHandler = null; if (telemetryContext != null) { telemetryContext.closeContext(); if (this.telemetryFailoverAdditionalTopTraceSetting) { @@ -728,6 +753,13 @@ protected void failoverWriter() throws SQLException { this.failoverWriterTriggeredCounter.inc(); } + if (this.writerFailoverHandler == null) { + if (this.writerFailoverHandlerSupplier == null) { + throw new SQLException(Messages.get("Failover.nullWriterFailoverHandlerSupplier")); + } + this.writerFailoverHandler = this.writerFailoverHandlerSupplier.apply(this.connectionService); + } + long failoverStartTimeNano = System.nanoTime(); try { @@ -740,7 +772,7 @@ protected void failoverWriter() throws SQLException { throw exception; } - updateHostAvailability(failoverResult.getHostAvailabilityMap()); + updateHostAvailability(this.writerFailoverHandler.getHostAvailabilityMap()); } if (failoverResult == null || !failoverResult.isConnected()) { @@ -797,6 +829,7 @@ protected void failoverWriter() throws SQLException { LOGGER.finest(() -> Messages.get( "Failover.writerFailoverElapsed", new Object[]{TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - failoverStartTimeNano)})); + this.writerFailoverHandler = null; if (telemetryContext != null) { telemetryContext.closeContext(); if (this.telemetryFailoverAdditionalTopTraceSetting) { @@ -807,15 +840,15 @@ protected void failoverWriter() throws SQLException { } private void updateHostAvailability(Map hostAvailabilityMap) { - if (hostAvailabilityMap != null && !hostAvailabilityMap.isEmpty()) { - List allHosts = this.pluginService.getAllHosts(); - for (HostSpec host : allHosts) { - for (String alias : host.getAliases()) { - HostAvailability availability = hostAvailabilityMap.get(alias); - if (availability != null) { - host.setAvailability(availability); - } - } + if (hostAvailabilityMap == null || hostAvailabilityMap.isEmpty()) { + return; + } + + List allHosts = this.pluginService.getAllHosts(); + for (HostSpec host : allHosts) { + HostAvailability availability = hostAvailabilityMap.get(host.getHost()); + if (availability != null) { + host.setAvailability(availability); } } } @@ -908,37 +941,7 @@ public Connection connect( final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - if (this.connectionService == null) { - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); - this.connectionService = new ConnectionServiceImpl( - servicesContainer.getStorageService(), - servicesContainer.getMonitorService(), - servicesContainer.getTelemetryFactory(), - defaultConnectionProvider, - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - properties - ); - } - this.initFailoverMode(); - if (this.readerFailoverHandler == null) { - if (this.readerFailoverHandlerSupplier == null) { - throw new SQLException(Messages.get("Failover.nullReaderFailoverHandlerSupplier")); - } - this.readerFailoverHandler = this.readerFailoverHandlerSupplier.apply(this.connectionService); - } - - if (this.writerFailoverHandler == null) { - if (this.writerFailoverHandlerSupplier == null) { - throw new SQLException(Messages.get("Failover.nullWriterFailoverHandlerSupplier")); - } - this.writerFailoverHandler = this.writerFailoverHandlerSupplier.apply(this.connectionService); - } Connection conn = null; try { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java index e006558b6..0f807fd54 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverHandler.java @@ -18,7 +18,9 @@ import java.sql.SQLException; import java.util.List; +import java.util.Map; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.hostavailability.HostAvailability; /** * Interface for Reader Failover Process handler. This handler implements all necessary logic to try @@ -47,4 +49,12 @@ public interface ReaderFailoverHandler { * @throws SQLException if any error occurred while attempting a reader connection. */ ReaderFailoverResult getReaderConnection(List hostList) throws SQLException; + + /** + * Get the host availability map for the failover handler. This map will be populated with host availability + * information during the failover process and can be used to determine which hosts are available. + * + * @return the host availability map for the failover handler. + */ + Map getHostAvailabilityMap(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java index 2ada37ca2..1367cc996 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ReaderFailoverResult.java @@ -18,9 +18,7 @@ import java.sql.Connection; import java.sql.SQLException; -import java.util.Map; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.hostavailability.HostAvailability; /** * This class holds results of Reader Failover Process. @@ -32,27 +30,23 @@ public class ReaderFailoverResult { private final boolean isConnected; private final SQLException exception; private final HostSpec newHost; - private final Map hostAvailabilityMap; public ReaderFailoverResult( final Connection newConnection, final HostSpec newHost, - final boolean isConnected, - final Map hostAvailabilityMap) { - this(newConnection, newHost, isConnected, null, hostAvailabilityMap); + final boolean isConnected) { + this(newConnection, newHost, isConnected, null); } public ReaderFailoverResult( final Connection newConnection, final HostSpec newHost, final boolean isConnected, - final SQLException exception, - final Map hostAvailabilityMap) { + final SQLException exception) { this.newConnection = newConnection; this.newHost = newHost; this.isConnected = isConnected; this.exception = exception; - this.hostAvailabilityMap = hostAvailabilityMap; } /** @@ -90,8 +84,4 @@ public boolean isConnected() { public SQLException getException() { return exception; } - - public Map getHostAvailabilityMap() { - return hostAvailabilityMap; - } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java index caae8de8c..f05c6a8c0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverHandler.java @@ -18,7 +18,9 @@ import java.sql.SQLException; import java.util.List; +import java.util.Map; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.hostavailability.HostAvailability; /** * Interface for Writer Failover Process handler. This handler implements all necessary logic to try @@ -34,4 +36,12 @@ public interface WriterFailoverHandler { * @throws SQLException indicating whether the failover attempt was successful. */ WriterFailoverResult failover(List currentTopology) throws SQLException; + + /** + * Get the host availability map for the failover handler. This map will be populated with host availability + * information during the failover process and can be used to determine which hosts are available. + * + * @return the host availability map for the failover handler. + */ + Map getHostAvailabilityMap(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java index e79f395a7..b1fc63a33 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/WriterFailoverResult.java @@ -19,9 +19,7 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.List; -import java.util.Map; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.hostavailability.HostAvailability; /** * This class holds results of Writer Failover Process. @@ -31,7 +29,6 @@ public class WriterFailoverResult { private final boolean isConnected; private final boolean isNewHost; private final List topology; - private final Map hostAvailabilityMap; private final Connection newConnection; private final String taskName; private final SQLException exception; @@ -40,24 +37,21 @@ public WriterFailoverResult( final boolean isConnected, final boolean isNewHost, final List topology, - final Map hostAvailabilityMap, final Connection newConnection, final String taskName) { - this(isConnected, isNewHost, topology, hostAvailabilityMap, newConnection, taskName, null); + this(isConnected, isNewHost, topology, newConnection, taskName, null); } public WriterFailoverResult( final boolean isConnected, final boolean isNewHost, final List topology, - final Map hostAvailabilityMap, final Connection newConnection, final String taskName, final SQLException exception) { this.isConnected = isConnected; this.isNewHost = isNewHost; this.topology = topology; - this.hostAvailabilityMap = hostAvailabilityMap; this.newConnection = newConnection; this.taskName = taskName; this.exception = exception; @@ -92,10 +86,6 @@ public List getTopology() { return this.topology; } - public Map getHostAvailabilityMap() { - return this.hostAvailabilityMap; - } - /** * Get the new connection established by the failover procedure if successful. * From 4df33da212cef40d4ec90085711209dafe7def59 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 13:13:47 -0700 Subject: [PATCH 12/54] Fix checkstyle --- .../amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 4b07f0996..285dc0eac 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -670,7 +670,8 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException failedHost = failedHostSpec; } - final ReaderFailoverResult result = this.readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); + final ReaderFailoverResult result = + this.readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); if (result != null) { final SQLException exception = result.getException(); if (exception != null) { From 7a863c9496f4ec5eca9d3ceca866eadf4cec3fc1 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 14:37:41 -0700 Subject: [PATCH 13/54] ReaderFailoverHandler tests passing --- .../ClusterAwareReaderFailoverHandler.java | 2 +- ...ClusterAwareReaderFailoverHandlerTest.java | 808 +++++++++--------- 2 files changed, 402 insertions(+), 408 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 4cb587f55..8f3ab2f62 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -394,7 +394,7 @@ private ReaderFailoverResult getNextResult(final CompletionService defaultHosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer").port(1234).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader1").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader2").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader3").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader4").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader5").port(1234).role(HostRole.READER).build() -// ); -// -// @BeforeEach -// void setUp() { -// closeable = MockitoAnnotations.openMocks(this); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// } -// -// @Test -// public void testFailover() throws SQLException { -// // original host list: [active writer, active reader, current connection (reader), active -// // reader, down reader, active reader] -// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] -// // connection attempts are made in pairs using the above list -// // expected test result: successful connection for host at index 4 -// final List hosts = defaultHosts; -// final int currentHostIndex = 2; -// final int successHostIndex = 4; -// for (int i = 0; i < hosts.size(); i++) { -// if (i != successHostIndex) { -// final SQLException exception = new SQLException("exception", "08S01", null); -// when(mockPluginService.forceConnect(hosts.get(i), properties)) -// .thenThrow(exception); -// when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); -// } else { -// when(mockPluginService.forceConnect(hosts.get(i), properties)).thenReturn(mockConnection); -// } -// } -// when(mockPluginService.getTargetDriverDialect()).thenReturn(null); -// -// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// final ReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties); -// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); -// -// assertTrue(result.isConnected()); -// assertSame(mockConnection, result.getConnection()); -// assertEquals(hosts.get(successHostIndex), result.getHost()); -// -// final HostSpec successHost = hosts.get(successHostIndex); -// verify(mockPluginService, atLeast(4)).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); -// verify(mockPluginService, never()) -// .setAvailability(eq(successHost.asAliases()), eq(HostAvailability.NOT_AVAILABLE)); -// verify(mockPluginService, times(1)) -// .setAvailability(eq(successHost.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// @Test -// public void testFailover_timeout() throws SQLException { -// // original host list: [active writer, active reader, current connection (reader), active -// // reader, down reader, active reader] -// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] -// // connection attempts are made in pairs using the above list -// // expected test result: failure to get reader since process is limited to 5s and each attempt -// // to connect takes 20s -// final List hosts = defaultHosts; -// final int currentHostIndex = 2; -// for (HostSpec host : hosts) { -// when(mockPluginService.forceConnect(host, properties)) -// .thenAnswer((Answer) invocation -> { -// Thread.sleep(20000); -// return mockConnection; -// }); -// } -// -// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// final ReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties, -// 5000, -// 30000, -// false); -// -// final long startTimeNano = System.nanoTime(); -// final ReaderFailoverResult result = -// target.failover(hosts, hosts.get(currentHostIndex)); -// final long durationNano = System.nanoTime() - startTimeNano; -// -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// -// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements -// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); -// } -// -// @Test -// public void testFailover_nullOrEmptyHostList() throws SQLException { -// final ClusterAwareReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties); -// final HostSpec currentHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer") -// .port(1234).build(); -// -// ReaderFailoverResult result = target.failover(null, currentHost); -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// -// final List hosts = new ArrayList<>(); -// result = target.failover(hosts, currentHost); -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// } -// -// @Test -// public void testGetReader_connectionSuccess() throws SQLException { -// // even number of connection attempts -// // first connection attempt to return succeeds, second attempt cancelled -// // expected test result: successful connection for host at index 2 -// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) -// final HostSpec slowHost = hosts.get(1); -// final HostSpec fastHost = hosts.get(2); -// when(mockPluginService.forceConnect(slowHost, properties)) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(20000); -// return mockConnection; -// }); -// when(mockPluginService.forceConnect(eq(fastHost), eq(properties))).thenReturn(mockConnection); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties); -// final ReaderFailoverResult result = target.getReaderConnection(hosts); -// -// assertTrue(result.isConnected()); -// assertSame(mockConnection, result.getConnection()); -// assertEquals(hosts.get(2), result.getHost()); -// -// verify(mockPluginService, never()).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); -// verify(mockPluginService, times(1)) -// .setAvailability(eq(fastHost.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// @Test -// public void testGetReader_connectionFailure() throws SQLException { -// // odd number of connection attempts -// // first connection attempt to return fails -// // expected test result: failure to get reader -// final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) -// when(mockPluginService.forceConnect(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final int currentHostIndex = 2; -// -// final ReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties); -// final ReaderFailoverResult result = target.getReaderConnection(hosts); -// -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// } -// -// @Test -// public void testGetReader_connectionAttemptsTimeout() throws SQLException { -// // connection attempts time out before they can succeed -// // first connection attempt to return times out -// // expected test result: failure to get reader -// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) -// when(mockPluginService.forceConnect(any(), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// try { -// Thread.sleep(5000); -// } catch (InterruptedException exception) { -// // ignore -// } -// return mockConnection; -// }); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ClusterAwareReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties, -// 60000, -// 1000, -// false); -// final ReaderFailoverResult result = target.getReaderConnection(hosts); -// -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// } -// -// @Test -// public void testGetHostTuplesByPriority() { -// final List originalHosts = defaultHosts; -// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// final ClusterAwareReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties); -// final List hostsByPriority = target.getHostsByPriority(originalHosts); -// -// int i = 0; -// -// // expecting active readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { -// i++; -// } -// -// // expecting a writer -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.WRITER) { -// i++; -// } -// -// // expecting down readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { -// i++; -// } -// -// assertEquals(hostsByPriority.size(), i); -// } -// -// @Test -// public void testGetReaderTuplesByPriority() { -// final List originalHosts = defaultHosts; -// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ClusterAwareReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties); -// final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); -// -// int i = 0; -// -// // expecting active readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { -// i++; -// } -// -// // expecting down readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { -// i++; -// } -// -// assertEquals(hostsByPriority.size(), i); -// } -// -// @Test -// public void testHostFailoverStrictReaderEnabled() { -// -// final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer").port(1234).role(HostRole.WRITER).build(); -// final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader1").port(1234).role(HostRole.READER).build(); -// final List hosts = Arrays.asList(writer, reader); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// final ClusterAwareReaderFailoverHandler target = -// new ClusterAwareReaderFailoverHandler( -// mockPluginService, -// properties, -// DEFAULT_FAILOVER_TIMEOUT, -// DEFAULT_READER_CONNECT_TIMEOUT, -// true); -// -// // The writer is included because the original writer has likely become a reader. -// List expectedHostsByPriority = Arrays.asList(reader, writer); -// -// List hostsByPriority = target.getHostsByPriority(hosts); -// assertEquals(expectedHostsByPriority, hostsByPriority); -// -// // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. -// reader.setAvailability(HostAvailability.NOT_AVAILABLE); -// expectedHostsByPriority = Arrays.asList(writer, reader); -// -// hostsByPriority = target.getHostsByPriority(hosts); -// assertEquals(expectedHostsByPriority, hostsByPriority); -// -// // Writer node will only be picked if it is the only node in topology; -// List expectedWriterHost = Collections.singletonList(writer); -// -// hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); -// assertEquals(expectedWriterHost, hostsByPriority); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.failover; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; +import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.stubbing.Answer; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionService; + +class ClusterAwareReaderFailoverHandlerTest { + @Mock FullServicesContainer mockContainer; + @Mock ConnectionService mockConnectionService; + @Mock PluginService mockPluginService; + @Mock ConnectionPluginManager mockPluginManager; + @Mock Connection mockConnection; + + private AutoCloseable closeable; + private final Properties properties = new Properties(); + private final List defaultHosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer").port(1234).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader1").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader2").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader3").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader4").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader5").port(1234).role(HostRole.READER).build() + ); + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); + when(mockContainer.getPluginService()).thenReturn(mockPluginService); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + public void testFailover() throws SQLException { + // original host list: [active writer, active reader, current connection (reader), active + // reader, down reader, active reader] + // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] + // connection attempts are made in pairs using the above list + // expected test result: successful connection for host at index 4 + final List hosts = defaultHosts; + final int currentHostIndex = 2; + final int successHostIndex = 4; + for (int i = 0; i < hosts.size(); i++) { + if (i != successHostIndex) { + final SQLException exception = new SQLException("exception", "08S01", null); + when(mockConnectionService.open(hosts.get(i), properties)) + .thenThrow(exception); + when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); + } else { + when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); + } + } + + when(mockPluginService.getTargetDriverDialect()).thenReturn(null); + + hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + + final ReaderFailoverHandler target = getSpyFailoverHandler(); + final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); + + assertTrue(result.isConnected()); + assertSame(mockConnection, result.getConnection()); + assertEquals(hosts.get(successHostIndex), result.getHost()); + + final HostSpec successHost = hosts.get(successHostIndex); + final Map availabilityMap = target.getHostAvailabilityMap(); + Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); + assertTrue(unavailableHosts.size() >= 4); + assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); + } + + private Set getHostsWithGivenAvailability( + Map availabilityMap, HostAvailability availability) { + return availabilityMap.entrySet().stream() + .filter((entry) -> availability.equals(entry.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + @Test + public void testFailover_timeout() throws SQLException { + // original host list: [active writer, active reader, current connection (reader), active + // reader, down reader, active reader] + // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] + // connection attempts are made in pairs using the above list + // expected test result: failure to get reader since process is limited to 5s and each attempt + // to connect takes 20s + final List hosts = defaultHosts; + final int currentHostIndex = 2; + for (HostSpec host : hosts) { + when(mockConnectionService.open(host, properties)) + .thenAnswer((Answer) invocation -> { + Thread.sleep(20000); + return mockConnection; + }); + } + + hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + + final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); + + final long startTimeNano = System.nanoTime(); + final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); + final long durationNano = System.nanoTime() - startTimeNano; + + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + + // 5s is a max allowed failover timeout; add 1s for inaccurate measurements + assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); + } + + private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() { + ClusterAwareReaderFailoverHandler handler = + spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); + doReturn(mockPluginService).when(handler).getNewPluginService(); + return handler; + } + + private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( + int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) { + ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( + mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); + ClusterAwareReaderFailoverHandler spyHandler = spy(handler); + doReturn(mockPluginService).when(spyHandler).getNewPluginService(); + return spyHandler; + } + + @Test + public void testFailover_nullOrEmptyHostList() throws SQLException { + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); + final HostSpec currentHost = + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); + + ReaderFailoverResult result = target.failover(null, currentHost); + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + + final List hosts = new ArrayList<>(); + result = target.failover(hosts, currentHost); + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + } + + @Test + public void testGetReader_connectionSuccess() throws SQLException { + // even number of connection attempts + // first connection attempt to return succeeds, second attempt cancelled + // expected test result: successful connection for host at index 2 + final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) + final HostSpec slowHost = hosts.get(1); + final HostSpec fastHost = hosts.get(2); + when(mockConnectionService.open(slowHost, properties)) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(20000); + return mockConnection; + }); + when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ReaderFailoverHandler target = getSpyFailoverHandler(); + final ReaderFailoverResult result = target.getReaderConnection(hosts); + + assertTrue(result.isConnected()); + assertSame(mockConnection, result.getConnection()); + assertEquals(hosts.get(2), result.getHost()); + + Map availabilityMap = target.getHostAvailabilityMap(); + assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); + assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); + } + + @Test + public void testGetReader_connectionFailure() throws SQLException { + // odd number of connection attempts + // first connection attempt to return fails + // expected test result: failure to get reader + final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) + when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ReaderFailoverHandler target = getSpyFailoverHandler(); + final ReaderFailoverResult result = target.getReaderConnection(hosts); + + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + } + + @Test + public void testGetReader_connectionAttemptsTimeout() throws SQLException { + // connection attempts time out before they can succeed + // first connection attempt to return times out + // expected test result: failure to get reader + final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) + when(mockConnectionService.open(any(), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + try { + Thread.sleep(5000); + } catch (InterruptedException exception) { + // ignore + } + return mockConnection; + }); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); + final ReaderFailoverResult result = target.getReaderConnection(hosts); + + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + } + + @Test + public void testGetHostTuplesByPriority() { + final List originalHosts = defaultHosts; + originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); + + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); + final List hostsByPriority = target.getHostsByPriority(originalHosts); + + int i = 0; + + // expecting active readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { + i++; + } + + // expecting a writer + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.WRITER) { + i++; + } + + // expecting down readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { + i++; + } + + assertEquals(hostsByPriority.size(), i); + } + + @Test + public void testGetReaderTuplesByPriority() { + final List originalHosts = defaultHosts; + originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); + final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); + + int i = 0; + + // expecting active readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { + i++; + } + + // expecting down readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { + i++; + } + + assertEquals(hostsByPriority.size(), i); + } + + @Test + public void testHostFailoverStrictReaderEnabled() { + + final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer").port(1234).role(HostRole.WRITER).build(); + final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader1").port(1234).role(HostRole.READER).build(); + final List hosts = Arrays.asList(writer, reader); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ClusterAwareReaderFailoverHandler target = + getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); + + // The writer is included because the original writer has likely become a reader. + List expectedHostsByPriority = Arrays.asList(reader, writer); + + List hostsByPriority = target.getHostsByPriority(hosts); + assertEquals(expectedHostsByPriority, hostsByPriority); + + // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. + reader.setAvailability(HostAvailability.NOT_AVAILABLE); + expectedHostsByPriority = Arrays.asList(writer, reader); + + hostsByPriority = target.getHostsByPriority(hosts); + assertEquals(expectedHostsByPriority, hostsByPriority); + + // Writer node will only be picked if it is the only node in topology; + List expectedWriterHost = Collections.singletonList(writer); + + hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); + assertEquals(expectedWriterHost, hostsByPriority); + } +} From 370d73fbf327ed13f7e3942cec980b11333acec7 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 15:06:33 -0700 Subject: [PATCH 14/54] WriterFailoverHandler tests passing --- .../ClusterAwareWriterFailoverHandler.java | 2 +- ...ClusterAwareWriterFailoverHandlerTest.java | 781 +++++++++--------- 2 files changed, 374 insertions(+), 409 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 8533b1055..1921f3238 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -189,7 +189,7 @@ private void submitTasks( executorService.shutdown(); } - private PluginService getNewPluginService() { + protected PluginService getNewPluginService() { // Each task should get its own PluginService since they execute concurrently and PluginService was not designed to // be thread-safe. return new PartialPluginService( diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 5f302209a..1ad394010 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -1,408 +1,373 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.failover; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertSame; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.ArgumentMatchers.refEq; -// import static org.mockito.Mockito.atLeastOnce; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.Arrays; -// import java.util.EnumSet; -// import java.util.List; -// import java.util.Properties; -// import java.util.concurrent.TimeUnit; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.ArgumentMatchers; -// import org.mockito.InOrder; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import org.mockito.stubbing.Answer; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// -// class ClusterAwareWriterFailoverHandlerTest { -// -// @Mock PluginService mockPluginService; -// @Mock Connection mockConnection; -// @Mock ReaderFailoverHandler mockReaderFailover; -// @Mock Connection mockWriterConnection; -// @Mock Connection mockNewWriterConnection; -// @Mock Connection mockReaderAConnection; -// @Mock Connection mockReaderBConnection; -// @Mock Dialect mockDialect; -// -// private AutoCloseable closeable; -// private final Properties properties = new Properties(); -// private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("new-writer-host").build(); -// private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer-host").build(); -// private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader-a-host").build(); -// private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader-b-host").build(); -// private final List topology = Arrays.asList(writer, readerA, readerB); -// private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); -// -// @BeforeEach -// void setUp() { -// closeable = MockitoAnnotations.openMocks(this); -// writer.addAlias("writer-host"); -// newWriterHost.addAlias("new-writer-host"); -// readerA.addAlias("reader-a-host"); -// readerB.addAlias("reader-b-host"); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// } -// -// @Test -// public void testReconnectToWriter_taskBReaderException() throws SQLException { -// when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockConnection); -// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenThrow(SQLException.class); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); -// -// when(mockPluginService.getAllHosts()).thenReturn(topology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 5000, -// 2000, -// 2000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertFalse(result.isNewHost()); -// assertSame(result.getNewConnection(), mockConnection); -// -// final InOrder inOrder = Mockito.inOrder(mockPluginService); -// inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a current writer node. -// * -// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. -// * TaskA: successfully re-connect to initial writer; return new connection. -// * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. -// * Expected test result: new connection by taskA. -// */ -// @Test -// public void testReconnectToWriter_SlowReaderA() throws SQLException { -// when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); -// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); -// when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return new ReaderFailoverResult(mockReaderAConnection, readerA, true); -// }); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 60000, -// 5000, -// 5000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertFalse(result.isNewHost()); -// assertSame(result.getNewConnection(), mockWriterConnection); -// -// final InOrder inOrder = Mockito.inOrder(mockPluginService); -// inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a current writer node. -// * -// *

Topology: no changes. -// * TaskA: successfully re-connect to writer; return new connection. -// * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). -// * Expected test result: new connection by taskA. -// */ -// @Test -// public void testReconnectToWriter_taskBDefers() throws SQLException { -// when(mockPluginService.forceConnect(refEq(writer), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return mockWriterConnection; -// }); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); -// -// when(mockPluginService.getAllHosts()).thenReturn(topology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 60000, -// 2000, -// 2000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertFalse(result.isNewHost()); -// assertSame(result.getNewConnection(), mockWriterConnection); -// -// final InOrder inOrder = Mockito.inOrder(mockPluginService); -// inOrder.verify(mockPluginService).setAvailability(eq(writer.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a new writer node. -// * -// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. -// * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more -// * time than taskB. -// * TaskB: successfully connect to readerA and then to new-writer. -// * Expected test result: new connection to writer by taskB. -// */ -// @Test -// public void testConnectToReaderA_SlowWriter() throws SQLException { -// when(mockPluginService.forceConnect(refEq(writer), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return mockWriterConnection; -// }); -// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); -// -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 60000, -// 5000, -// 5000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertTrue(result.isNewHost()); -// assertSame(result.getNewConnection(), mockNewWriterConnection); -// assertEquals(3, result.getTopology().size()); -// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); -// -// verify(mockPluginService, times(1)).setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a new writer node. -// * -// *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. -// * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). -// * TaskB: successfully connect to readerA and then to new-writer. -// * Expected test result: new connection to writer by taskB. -// */ -// @Test -// public void testConnectToReaderA_taskADefers() throws SQLException { -// when(mockPluginService.forceConnect(writer, properties)).thenReturn(mockConnection); -// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return mockNewWriterConnection; -// }); -// -// final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 60000, -// 5000, -// 5000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertTrue(result.isNewHost()); -// assertSame(result.getNewConnection(), mockNewWriterConnection); -// assertEquals(4, result.getTopology().size()); -// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); -// -// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); -// verify(mockPluginService, times(1)).setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.AVAILABLE)); -// } -// -// /** -// * Verify that writer failover handler fails to re-connect to any writer node. -// * -// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. -// * TaskA: fail to re-connect to writer due to failover timeout. -// * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. -// * Expected test result: no connection. -// */ -// @Test -// public void testFailedToConnect_failoverTimeout() throws SQLException { -// when(mockPluginService.forceConnect(refEq(writer), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(30000); -// return mockWriterConnection; -// }); -// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(30000); -// return mockNewWriterConnection; -// }); -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 5000, -// 2000, -// 2000); -// -// final long startTimeNano = System.nanoTime(); -// final WriterFailoverResult result = target.failover(topology); -// final long durationNano = System.nanoTime() - startTimeNano; -// -// assertFalse(result.isConnected()); -// assertFalse(result.isNewHost()); -// -// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); -// -// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements -// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); -// } -// -// /** -// * Verify that writer failover handler fails to re-connect to any writer node. -// * -// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. -// * TaskA: fail to re-connect to writer due to exception. -// * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. -// * Expected test result: no connection. -// */ -// @Test -// public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { -// final SQLException exception = new SQLException("exception", "08S01", null); -// when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenThrow(exception); -// when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenThrow(exception); -// when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); -// -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = -// new ClusterAwareWriterFailoverHandler( -// mockPluginService, -// mockReaderFailover, -// properties, -// 5000, -// 2000, -// 2000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertFalse(result.isConnected()); -// assertFalse(result.isNewHost()); -// -// verify(mockPluginService, atLeastOnce()) -// .setAvailability(eq(newWriterHost.asAliases()), eq(HostAvailability.NOT_AVAILABLE)); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.failover; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.refEq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.stubbing.Answer; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionService; + +class ClusterAwareWriterFailoverHandlerTest { + @Mock FullServicesContainer mockContainer; + @Mock ConnectionService mockConnectionService; + @Mock PluginService mockPluginService; + @Mock Connection mockConnection; + @Mock ReaderFailoverHandler mockReaderFailoverHandler; + @Mock Connection mockWriterConnection; + @Mock Connection mockNewWriterConnection; + @Mock Connection mockReaderAConnection; + @Mock Connection mockReaderBConnection; + @Mock Dialect mockDialect; + + private AutoCloseable closeable; + private final Properties properties = new Properties(); + private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("new-writer-host").build(); + private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer-host").build(); + private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader-a-host").build(); + private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader-b-host").build(); + private final List topology = Arrays.asList(writer, readerA, readerB); + private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + when(mockContainer.getPluginService()).thenReturn(mockPluginService); + writer.addAlias("writer-host"); + newWriterHost.addAlias("new-writer-host"); + readerA.addAlias("reader-a-host"); + readerB.addAlias("reader-b-host"); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + public void testReconnectToWriter_taskBReaderException() throws SQLException { + when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); + when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); + + when(mockPluginService.getAllHosts()).thenReturn(topology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertFalse(result.isNewHost()); + assertSame(result.getNewConnection(), mockConnection); + + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); + } + + private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( + final int failoverTimeoutMs, + final int readTopologyIntervalMs, + final int reconnectWriterIntervalMs) { + ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( + mockContainer, + mockConnectionService, + mockReaderFailoverHandler, + properties, + failoverTimeoutMs, + readTopologyIntervalMs, + reconnectWriterIntervalMs); + + ClusterAwareWriterFailoverHandler spyHandler = spy(handler); + doReturn(mockPluginService).when(spyHandler).getNewPluginService(); + return spyHandler; + } + + /** + * Verify that writer failover handler can re-connect to a current writer node. + * + *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. + * TaskA: successfully re-connect to initial writer; return new connection. + * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. + * Expected test result: new connection by taskA. + */ + @Test + public void testReconnectToWriter_SlowReaderA() throws SQLException { + when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); + when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); + when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return new ReaderFailoverResult(mockReaderAConnection, readerA, true); + }); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertFalse(result.isNewHost()); + assertSame(result.getNewConnection(), mockWriterConnection); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); + } + + /** + * Verify that writer failover handler can re-connect to a current writer node. + * + *

Topology: no changes. + * TaskA: successfully re-connect to writer; return new connection. + * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). + * Expected test result: new connection by taskA. + */ + @Test + public void testReconnectToWriter_taskBDefers() throws SQLException { + when(mockConnectionService.open(refEq(writer), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return mockWriterConnection; + }); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); + + when(mockPluginService.getAllHosts()).thenReturn(topology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertFalse(result.isNewHost()); + assertSame(result.getNewConnection(), mockWriterConnection); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); + } + + /** + * Verify that writer failover handler can re-connect to a new writer node. + * + *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. + * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more + * time than taskB. + * TaskB: successfully connect to readerA and then to new-writer. + * Expected test result: new connection to writer by taskB. + */ + @Test + public void testConnectToReaderA_SlowWriter() throws SQLException { + when(mockConnectionService.open(refEq(writer), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return mockWriterConnection; + }); + when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); + + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertTrue(result.isNewHost()); + assertSame(result.getNewConnection(), mockNewWriterConnection); + assertEquals(3, result.getTopology().size()); + assertEquals("new-writer-host", result.getTopology().get(0).getHost()); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); + } + + /** + * Verify that writer failover handler can re-connect to a new writer node. + * + *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. + * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). + * TaskB: successfully connect to readerA and then to new-writer. + * Expected test result: new connection to writer by taskB. + */ + @Test + public void testConnectToReaderA_taskADefers() throws SQLException { + when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); + when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return mockNewWriterConnection; + }); + + final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertTrue(result.isNewHost()); + assertSame(result.getNewConnection(), mockNewWriterConnection); + assertEquals(4, result.getTopology().size()); + assertEquals("new-writer-host", result.getTopology().get(0).getHost()); + + verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); + } + + /** + * Verify that writer failover handler fails to re-connect to any writer node. + * + *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. + * TaskA: fail to re-connect to writer due to failover timeout. + * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. + * Expected test result: no connection. + */ + @Test + public void testFailedToConnect_failoverTimeout() throws SQLException { + when(mockConnectionService.open(refEq(writer), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(30000); + return mockWriterConnection; + }); + when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(30000); + return mockNewWriterConnection; + }); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); + + final long startTimeNano = System.nanoTime(); + final WriterFailoverResult result = target.failover(topology); + final long durationNano = System.nanoTime() - startTimeNano; + + assertFalse(result.isConnected()); + assertFalse(result.isNewHost()); + + verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); + + // 5s is a max allowed failover timeout; add 1s for inaccurate measurements + assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); + } + + /** + * Verify that writer failover handler fails to re-connect to any writer node. + * + *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. + * TaskA: fail to re-connect to writer due to exception. + * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. + * Expected test result: no connection. + */ + @Test + public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { + final SQLException exception = new SQLException("exception", "08S01", null); + when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); + when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); + when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); + + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); + final WriterFailoverResult result = target.failover(topology); + + assertFalse(result.isConnected()); + assertFalse(result.isNewHost()); + + assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); + } +} From 435f71a317d817d470d969b2503d199e995a62e1 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 15:19:06 -0700 Subject: [PATCH 15/54] FailoverConnectionPluginTests passing --- .../failover/FailoverConnectionPlugin.java | 32 +- .../FailoverConnectionPluginTest.java | 896 +++++++++--------- 2 files changed, 472 insertions(+), 456 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 285dc0eac..8bbdbd5fd 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -621,20 +621,7 @@ protected void dealWithIllegalStateException( protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); if (this.connectionService == null) { - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); - this.connectionService = new ConnectionServiceImpl( - servicesContainer.getStorageService(), - servicesContainer.getMonitorService(), - servicesContainer.getTelemetryFactory(), - defaultConnectionProvider, - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - properties - ); + this.connectionService = getConnectionService(); } if (this.failoverMode == FailoverMode.STRICT_WRITER) { @@ -644,6 +631,23 @@ protected void failover(final HostSpec failedHost) throws SQLException { } } + protected ConnectionService getConnectionService() throws SQLException { + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + return new ConnectionServiceImpl( + servicesContainer.getStorageService(), + servicesContainer.getMonitorService(), + servicesContainer.getTelemetryFactory(), + defaultConnectionProvider, + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect(), + properties + ); + } + protected void failoverReader(final HostSpec failedHostSpec) throws SQLException { TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext( diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index 33160333f..adf1cb0e6 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -1,442 +1,454 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.failover; -// -// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.atLeastOnce; -// import static org.mockito.Mockito.doNothing; -// import static org.mockito.Mockito.doThrow; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.HashMap; -// import java.util.HashSet; -// import java.util.List; -// import java.util.Map; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.ValueSource; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.NodeChangeOptions; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.RdsUrlType; -// import software.amazon.jdbc.util.SqlState; -// import software.amazon.jdbc.util.telemetry.GaugeCallable; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryGauge; -// -// class FailoverConnectionPluginTest { -// -// private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; -// private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; -// private static final Object[] EMPTY_ARGS = {}; -// private final List defaultHosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer").port(1234).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader1").port(1234).role(HostRole.READER).build()); -// -// @Mock PluginService mockPluginService; -// @Mock Connection mockConnection; -// @Mock HostSpec mockHostSpec; -// @Mock HostListProviderService mockHostListProviderService; -// @Mock AuroraHostListProvider mockHostListProvider; -// @Mock JdbcCallable mockInitHostProviderFunc; -// @Mock ReaderFailoverHandler mockReaderFailoverHandler; -// @Mock WriterFailoverHandler mockWriterFailoverHandler; -// @Mock ReaderFailoverResult mockReaderResult; -// @Mock WriterFailoverResult mockWriterResult; -// @Mock JdbcCallable mockSqlFunction; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryGauge mockTelemetryGauge; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// -// -// private final Properties properties = new Properties(); -// private FailoverConnectionPlugin plugin; -// private AutoCloseable closeable; -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @BeforeEach -// void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// -// when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); -// when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); -// when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockPluginService.getHosts()).thenReturn(defaultHosts); -// when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); -// when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); -// when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); -// when(mockWriterResult.isConnected()).thenReturn(true); -// when(mockWriterResult.getTopology()).thenReturn(defaultHosts); -// when(mockReaderResult.isConnected()).thenReturn(true); -// -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// // noinspection unchecked -// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); -// -// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); -// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); -// -// properties.clear(); -// } -// -// @Test -// void test_notifyNodeListChanged_withFailoverDisabled() { -// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); -// final Map> changes = new HashMap<>(); -// -// initializePlugin(); -// plugin.notifyNodeListChanged(changes); -// -// verify(mockPluginService, never()).getCurrentHostSpec(); -// verify(mockHostSpec, never()).getAliases(); -// } -// -// @Test -// void test_notifyNodeListChanged_withValidConnectionNotInTopology() { -// final Map> changes = new HashMap<>(); -// changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); -// changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); -// -// initializePlugin(); -// plugin.notifyNodeListChanged(changes); -// -// when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); -// -// verify(mockPluginService).getCurrentHostSpec(); -// verify(mockHostSpec, never()).getAliases(); -// } -// -// @Test -// void test_updateTopology() throws SQLException { -// initializePlugin(); -// -// // Test updateTopology with failover disabled -// plugin.setRdsUrlType(RdsUrlType.RDS_PROXY); -// plugin.updateTopology(false); -// verify(mockPluginService, never()).forceRefreshHostList(); -// verify(mockPluginService, never()).refreshHostList(); -// -// // Test updateTopology with no connection -// when(mockPluginService.getCurrentHostSpec()).thenReturn(null); -// plugin.updateTopology(false); -// verify(mockPluginService, never()).forceRefreshHostList(); -// verify(mockPluginService, never()).refreshHostList(); -// -// // Test updateTopology with closed connection -// when(mockConnection.isClosed()).thenReturn(true); -// plugin.updateTopology(false); -// verify(mockPluginService, never()).forceRefreshHostList(); -// verify(mockPluginService, never()).refreshHostList(); -// } -// -// @ParameterizedTest -// @ValueSource(booleans = {true, false}) -// void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { -// -// when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); -// when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); -// when(mockConnection.isClosed()).thenReturn(false); -// initializePlugin(); -// plugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); -// -// plugin.updateTopology(forceUpdate); -// if (forceUpdate) { -// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); -// } else { -// verify(mockPluginService, atLeastOnce()).refreshHostList(); -// } -// } -// -// @Test -// void test_failover_failoverWriter() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(true); -// -// initializePlugin(); -// final FailoverConnectionPlugin spyPlugin = spy(plugin); -// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); -// spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; -// -// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); -// verify(spyPlugin).failoverWriter(); -// } -// -// @Test -// void test_failover_failoverReader() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(false); -// -// initializePlugin(); -// final FailoverConnectionPlugin spyPlugin = spy(plugin); -// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); -// spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; -// -// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); -// verify(spyPlugin).failoverReader(eq(mockHostSpec)); -// } -// -// @Test -// void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); -// when(mockReaderResult.isConnected()).thenReturn(true); -// when(mockReaderResult.getConnection()).thenReturn(mockConnection); -// when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); -// -// initializePlugin(); -// plugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// () -> mockReaderFailoverHandler, -// () -> mockWriterFailoverHandler); -// -// final FailoverConnectionPlugin spyPlugin = spy(plugin); -// doNothing().when(spyPlugin).updateTopology(true); -// -// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); -// -// verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); -// verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); -// } -// -// @Test -// void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .build(); -// final List hosts = Collections.singletonList(hostSpec); -// -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); -// when(mockPluginService.getAllHosts()).thenReturn(hosts); -// when(mockPluginService.getHosts()).thenReturn(hosts); -// when(mockReaderResult.getException()).thenReturn(new SQLException()); -// when(mockReaderResult.getHost()).thenReturn(hostSpec); -// -// initializePlugin(); -// plugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// () -> mockReaderFailoverHandler, -// () -> mockWriterFailoverHandler); -// -// assertThrows(SQLException.class, () -> plugin.failoverReader(null)); -// verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); -// } -// -// @Test -// void test_failoverWriter_failedFailover_throwsException() throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .build(); -// final List hosts = Collections.singletonList(hostSpec); -// -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockPluginService.getAllHosts()).thenReturn(hosts); -// when(mockPluginService.getHosts()).thenReturn(hosts); -// when(mockWriterResult.getException()).thenReturn(new SQLException()); -// -// initializePlugin(); -// plugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// () -> mockReaderFailoverHandler, -// () -> mockWriterFailoverHandler); -// -// assertThrows(SQLException.class, () -> plugin.failoverWriter()); -// verify(mockWriterFailoverHandler).failover(eq(hosts)); -// } -// -// @Test -// void test_failoverWriter_failedFailover_withNoResult() throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .build(); -// final List hosts = Collections.singletonList(hostSpec); -// -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockPluginService.getAllHosts()).thenReturn(hosts); -// when(mockPluginService.getHosts()).thenReturn(hosts); -// when(mockWriterResult.isConnected()).thenReturn(false); -// -// initializePlugin(); -// plugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// () -> mockReaderFailoverHandler, -// () -> mockWriterFailoverHandler); -// -// final SQLException exception = assertThrows(SQLException.class, () -> plugin.failoverWriter()); -// assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); -// -// verify(mockWriterFailoverHandler).failover(eq(hosts)); -// verify(mockWriterResult, never()).getNewConnection(); -// verify(mockWriterResult, never()).getTopology(); -// } -// -// @Test -// void test_failoverWriter_successFailover() throws SQLException { -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// -// initializePlugin(); -// plugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// () -> mockReaderFailoverHandler, -// () -> mockWriterFailoverHandler); -// -// final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> plugin.failoverWriter()); -// assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); -// -// verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); -// } -// -// @Test -// void test_invalidCurrentConnection_withNoConnection() { -// when(mockPluginService.getCurrentConnection()).thenReturn(null); -// initializePlugin(); -// plugin.invalidateCurrentConnection(); -// -// verify(mockPluginService, never()).getCurrentHostSpec(); -// } -// -// @Test -// void test_invalidateCurrentConnection_inTransaction() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(true); -// when(mockHostSpec.getHost()).thenReturn("host"); -// when(mockHostSpec.getPort()).thenReturn(123); -// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); -// -// initializePlugin(); -// plugin.invalidateCurrentConnection(); -// verify(mockConnection).rollback(); -// -// // Assert SQL exceptions thrown during rollback do not get propagated. -// doThrow(new SQLException()).when(mockConnection).rollback(); -// assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); -// } -// -// @Test -// void test_invalidateCurrentConnection_notInTransaction() { -// when(mockPluginService.isInTransaction()).thenReturn(false); -// when(mockHostSpec.getHost()).thenReturn("host"); -// when(mockHostSpec.getPort()).thenReturn(123); -// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); -// -// initializePlugin(); -// plugin.invalidateCurrentConnection(); -// -// verify(mockPluginService).isInTransaction(); -// } -// -// @Test -// void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(false); -// when(mockConnection.isClosed()).thenReturn(false); -// when(mockHostSpec.getHost()).thenReturn("host"); -// when(mockHostSpec.getPort()).thenReturn(123); -// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); -// -// initializePlugin(); -// plugin.invalidateCurrentConnection(); -// -// doThrow(new SQLException()).when(mockConnection).close(); -// assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); -// -// verify(mockConnection, times(2)).isClosed(); -// verify(mockConnection, times(2)).close(); -// } -// -// @Test -// void test_execute_withFailoverDisabled() throws SQLException { -// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); -// initializePlugin(); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// mockSqlFunction, -// EMPTY_ARGS); -// -// verify(mockSqlFunction).call(); -// verify(mockHostListProvider, never()).getRdsUrlType(); -// } -// -// @Test -// void test_execute_withDirectExecute() throws SQLException { -// initializePlugin(); -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// "close", -// mockSqlFunction, -// EMPTY_ARGS); -// verify(mockSqlFunction).call(); -// verify(mockHostListProvider, never()).getRdsUrlType(); -// } -// -// private void initializePlugin() { -// plugin = new FailoverConnectionPlugin(mockPluginService, properties); -// plugin.setWriterFailoverHandler(mockWriterFailoverHandler); -// plugin.setReaderFailoverHandler(mockReaderFailoverHandler); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.failover; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.NodeChangeOptions; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.RdsUrlType; +import software.amazon.jdbc.util.SqlState; +import software.amazon.jdbc.util.connection.ConnectionService; +import software.amazon.jdbc.util.telemetry.GaugeCallable; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +class FailoverConnectionPluginTest { + + private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; + private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; + private static final Object[] EMPTY_ARGS = {}; + private final List defaultHosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer").port(1234).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader1").port(1234).role(HostRole.READER).build()); + + @Mock FullServicesContainer mockContainer; + @Mock ConnectionService mockConnectionService; + @Mock PluginService mockPluginService; + @Mock Connection mockConnection; + @Mock HostSpec mockHostSpec; + @Mock HostListProviderService mockHostListProviderService; + @Mock AuroraHostListProvider mockHostListProvider; + @Mock JdbcCallable mockInitHostProviderFunc; + @Mock ReaderFailoverHandler mockReaderFailoverHandler; + @Mock WriterFailoverHandler mockWriterFailoverHandler; + @Mock ReaderFailoverResult mockReaderResult; + @Mock WriterFailoverResult mockWriterResult; + @Mock JdbcCallable mockSqlFunction; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryGauge mockTelemetryGauge; + @Mock TargetDriverDialect mockTargetDriverDialect; + + + private final Properties properties = new Properties(); + private FailoverConnectionPlugin plugin; + private AutoCloseable closeable; + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @BeforeEach + void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + + when(mockContainer.getPluginService()).thenReturn(mockPluginService); + when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); + when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); + when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockPluginService.getHosts()).thenReturn(defaultHosts); + when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); + when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); + when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); + when(mockWriterResult.isConnected()).thenReturn(true); + when(mockWriterResult.getTopology()).thenReturn(defaultHosts); + when(mockReaderResult.isConnected()).thenReturn(true); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + // noinspection unchecked + when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + + when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); + when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); + + properties.clear(); + } + + @Test + void test_notifyNodeListChanged_withFailoverDisabled() { + properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); + final Map> changes = new HashMap<>(); + + initializePlugin(); + plugin.notifyNodeListChanged(changes); + + verify(mockPluginService, never()).getCurrentHostSpec(); + verify(mockHostSpec, never()).getAliases(); + } + + @Test + void test_notifyNodeListChanged_withValidConnectionNotInTopology() { + final Map> changes = new HashMap<>(); + changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); + changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); + + initializePlugin(); + plugin.notifyNodeListChanged(changes); + + when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); + + verify(mockPluginService).getCurrentHostSpec(); + verify(mockHostSpec, never()).getAliases(); + } + + @Test + void test_updateTopology() throws SQLException { + initializePlugin(); + + // Test updateTopology with failover disabled + plugin.setRdsUrlType(RdsUrlType.RDS_PROXY); + plugin.updateTopology(false); + verify(mockPluginService, never()).forceRefreshHostList(); + verify(mockPluginService, never()).refreshHostList(); + + // Test updateTopology with no connection + when(mockPluginService.getCurrentHostSpec()).thenReturn(null); + plugin.updateTopology(false); + verify(mockPluginService, never()).forceRefreshHostList(); + verify(mockPluginService, never()).refreshHostList(); + + // Test updateTopology with closed connection + when(mockConnection.isClosed()).thenReturn(true); + plugin.updateTopology(false); + verify(mockPluginService, never()).forceRefreshHostList(); + verify(mockPluginService, never()).refreshHostList(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { + + when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); + when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); + when(mockConnection.isClosed()).thenReturn(false); + initializePlugin(); + plugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); + + plugin.updateTopology(forceUpdate); + if (forceUpdate) { + verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); + } else { + verify(mockPluginService, atLeastOnce()).refreshHostList(); + } + } + + @Test + void test_failover_failoverWriter() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(true); + + initializePlugin(); + doThrow(FailoverSuccessSQLException.class).when(plugin).failoverWriter(); + plugin.failoverMode = FailoverMode.STRICT_WRITER; + + assertThrows(FailoverSuccessSQLException.class, () -> plugin.failover(mockHostSpec)); + verify(plugin).failoverWriter(); + } + + @Test + void test_failover_failoverReader() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(false); + + initializePlugin(); + doThrow(FailoverSuccessSQLException.class).when(plugin).failoverReader(eq(mockHostSpec)); + plugin.failoverMode = FailoverMode.READER_OR_WRITER; + + assertThrows(FailoverSuccessSQLException.class, () -> plugin.failover(mockHostSpec)); + verify(plugin).failoverReader(eq(mockHostSpec)); + } + + @Test + void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); + when(mockReaderResult.isConnected()).thenReturn(true); + when(mockReaderResult.getConnection()).thenReturn(mockConnection); + when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); + + initializePlugin(); + plugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + (connectionService) -> mockReaderFailoverHandler, + (connectionService) -> mockWriterFailoverHandler); + + final FailoverConnectionPlugin spyPlugin = spy(plugin); + doNothing().when(spyPlugin).updateTopology(true); + + assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); + + verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); + verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); + } + + @Test + void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .build(); + final List hosts = Collections.singletonList(hostSpec); + + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); + when(mockPluginService.getAllHosts()).thenReturn(hosts); + when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockReaderResult.getException()).thenReturn(new SQLException()); + when(mockReaderResult.getHost()).thenReturn(hostSpec); + + initializePlugin(); + plugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + (connectionService) -> mockReaderFailoverHandler, + (connectionService) -> mockWriterFailoverHandler); + + assertThrows(SQLException.class, () -> plugin.failoverReader(null)); + verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); + } + + @Test + void test_failoverWriter_failedFailover_throwsException() throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .build(); + final List hosts = Collections.singletonList(hostSpec); + + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); + when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockWriterResult.getException()).thenReturn(new SQLException()); + + initializePlugin(); + plugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + (connectionService) -> mockReaderFailoverHandler, + (connectionService) -> mockWriterFailoverHandler); + + assertThrows(SQLException.class, () -> plugin.failoverWriter()); + verify(mockWriterFailoverHandler).failover(eq(hosts)); + } + + @Test + void test_failoverWriter_failedFailover_withNoResult() throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .build(); + final List hosts = Collections.singletonList(hostSpec); + + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); + when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockWriterResult.isConnected()).thenReturn(false); + + initializePlugin(); + plugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + (connectionService) -> mockReaderFailoverHandler, + (connectionService) -> mockWriterFailoverHandler); + + final SQLException exception = assertThrows(SQLException.class, () -> plugin.failoverWriter()); + assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); + + verify(mockWriterFailoverHandler).failover(eq(hosts)); + verify(mockWriterResult, never()).getNewConnection(); + verify(mockWriterResult, never()).getTopology(); + } + + @Test + void test_failoverWriter_successFailover() throws SQLException { + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + + initializePlugin(); + plugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + (connectionService) -> mockReaderFailoverHandler, + (connectionService) -> mockWriterFailoverHandler); + + final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> plugin.failoverWriter()); + assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); + + verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); + } + + @Test + void test_invalidCurrentConnection_withNoConnection() { + when(mockPluginService.getCurrentConnection()).thenReturn(null); + initializePlugin(); + plugin.invalidateCurrentConnection(); + + verify(mockPluginService, never()).getCurrentHostSpec(); + } + + @Test + void test_invalidateCurrentConnection_inTransaction() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockHostSpec.getHost()).thenReturn("host"); + when(mockHostSpec.getPort()).thenReturn(123); + when(mockHostSpec.getRole()).thenReturn(HostRole.READER); + + initializePlugin(); + plugin.invalidateCurrentConnection(); + verify(mockConnection).rollback(); + + // Assert SQL exceptions thrown during rollback do not get propagated. + doThrow(new SQLException()).when(mockConnection).rollback(); + assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); + } + + @Test + void test_invalidateCurrentConnection_notInTransaction() { + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockHostSpec.getHost()).thenReturn("host"); + when(mockHostSpec.getPort()).thenReturn(123); + when(mockHostSpec.getRole()).thenReturn(HostRole.READER); + + initializePlugin(); + plugin.invalidateCurrentConnection(); + + verify(mockPluginService).isInTransaction(); + } + + @Test + void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.isClosed()).thenReturn(false); + when(mockHostSpec.getHost()).thenReturn("host"); + when(mockHostSpec.getPort()).thenReturn(123); + when(mockHostSpec.getRole()).thenReturn(HostRole.READER); + + initializePlugin(); + plugin.invalidateCurrentConnection(); + + doThrow(new SQLException()).when(mockConnection).close(); + assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); + + verify(mockConnection, times(2)).isClosed(); + verify(mockConnection, times(2)).close(); + } + + @Test + void test_execute_withFailoverDisabled() throws SQLException { + properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); + initializePlugin(); + + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + mockSqlFunction, + EMPTY_ARGS); + + verify(mockSqlFunction).call(); + verify(mockHostListProvider, never()).getRdsUrlType(); + } + + @Test + void test_execute_withDirectExecute() throws SQLException { + initializePlugin(); + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + "close", + mockSqlFunction, + EMPTY_ARGS); + verify(mockSqlFunction).call(); + verify(mockHostListProvider, never()).getRdsUrlType(); + } + + private void initializePlugin() { + plugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); + plugin.setWriterFailoverHandler(mockWriterFailoverHandler); + plugin.setReaderFailoverHandler(mockReaderFailoverHandler); + + try { + doReturn(mockConnectionService).when(plugin).getConnectionService(); + } catch (SQLException e) { + fail( + "Encountered exception when trying to stub FailoverConnectionPlugin#getConnectionService: " + e.getMessage()); + } + } +} From 6ebf706acd042675d2629dded3b200da14df57b5 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 15:48:24 -0700 Subject: [PATCH 16/54] Rename plugin to spyPlugin in FailoverConnectionPluginTest --- .../FailoverConnectionPluginTest.java | 78 +++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index adf1cb0e6..045482764 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -104,7 +104,7 @@ class FailoverConnectionPluginTest { private final Properties properties = new Properties(); - private FailoverConnectionPlugin plugin; + private FailoverConnectionPlugin spyPlugin; private AutoCloseable closeable; @AfterEach @@ -150,7 +150,7 @@ void test_notifyNodeListChanged_withFailoverDisabled() { final Map> changes = new HashMap<>(); initializePlugin(); - plugin.notifyNodeListChanged(changes); + spyPlugin.notifyNodeListChanged(changes); verify(mockPluginService, never()).getCurrentHostSpec(); verify(mockHostSpec, never()).getAliases(); @@ -163,7 +163,7 @@ void test_notifyNodeListChanged_withValidConnectionNotInTopology() { changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); initializePlugin(); - plugin.notifyNodeListChanged(changes); + spyPlugin.notifyNodeListChanged(changes); when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); @@ -177,20 +177,20 @@ void test_updateTopology() throws SQLException { initializePlugin(); // Test updateTopology with failover disabled - plugin.setRdsUrlType(RdsUrlType.RDS_PROXY); - plugin.updateTopology(false); + spyPlugin.setRdsUrlType(RdsUrlType.RDS_PROXY); + spyPlugin.updateTopology(false); verify(mockPluginService, never()).forceRefreshHostList(); verify(mockPluginService, never()).refreshHostList(); // Test updateTopology with no connection when(mockPluginService.getCurrentHostSpec()).thenReturn(null); - plugin.updateTopology(false); + spyPlugin.updateTopology(false); verify(mockPluginService, never()).forceRefreshHostList(); verify(mockPluginService, never()).refreshHostList(); // Test updateTopology with closed connection when(mockConnection.isClosed()).thenReturn(true); - plugin.updateTopology(false); + spyPlugin.updateTopology(false); verify(mockPluginService, never()).forceRefreshHostList(); verify(mockPluginService, never()).refreshHostList(); } @@ -205,9 +205,9 @@ void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLEx new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); when(mockConnection.isClosed()).thenReturn(false); initializePlugin(); - plugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); + spyPlugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); - plugin.updateTopology(forceUpdate); + spyPlugin.updateTopology(forceUpdate); if (forceUpdate) { verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); } else { @@ -220,11 +220,11 @@ void test_failover_failoverWriter() throws SQLException { when(mockPluginService.isInTransaction()).thenReturn(true); initializePlugin(); - doThrow(FailoverSuccessSQLException.class).when(plugin).failoverWriter(); - plugin.failoverMode = FailoverMode.STRICT_WRITER; + doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); + spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; - assertThrows(FailoverSuccessSQLException.class, () -> plugin.failover(mockHostSpec)); - verify(plugin).failoverWriter(); + assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); + verify(spyPlugin).failoverWriter(); } @Test @@ -232,11 +232,11 @@ void test_failover_failoverReader() throws SQLException { when(mockPluginService.isInTransaction()).thenReturn(false); initializePlugin(); - doThrow(FailoverSuccessSQLException.class).when(plugin).failoverReader(eq(mockHostSpec)); - plugin.failoverMode = FailoverMode.READER_OR_WRITER; + doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); + spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; - assertThrows(FailoverSuccessSQLException.class, () -> plugin.failover(mockHostSpec)); - verify(plugin).failoverReader(eq(mockHostSpec)); + assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); + verify(spyPlugin).failoverReader(eq(mockHostSpec)); } @Test @@ -248,13 +248,13 @@ void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLExc when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); initializePlugin(); - plugin.initHostProvider( + spyPlugin.initHostProvider( mockHostListProviderService, mockInitHostProviderFunc, (connectionService) -> mockReaderFailoverHandler, (connectionService) -> mockWriterFailoverHandler); - final FailoverConnectionPlugin spyPlugin = spy(plugin); + final FailoverConnectionPlugin spyPlugin = spy(this.spyPlugin); doNothing().when(spyPlugin).updateTopology(true); assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); @@ -277,13 +277,13 @@ void test_failoverReader_withNoFailedHostSpec_withException() throws SQLExceptio when(mockReaderResult.getHost()).thenReturn(hostSpec); initializePlugin(); - plugin.initHostProvider( + spyPlugin.initHostProvider( mockHostListProviderService, mockInitHostProviderFunc, (connectionService) -> mockReaderFailoverHandler, (connectionService) -> mockWriterFailoverHandler); - assertThrows(SQLException.class, () -> plugin.failoverReader(null)); + assertThrows(SQLException.class, () -> spyPlugin.failoverReader(null)); verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); } @@ -299,13 +299,13 @@ void test_failoverWriter_failedFailover_throwsException() throws SQLException { when(mockWriterResult.getException()).thenReturn(new SQLException()); initializePlugin(); - plugin.initHostProvider( + spyPlugin.initHostProvider( mockHostListProviderService, mockInitHostProviderFunc, (connectionService) -> mockReaderFailoverHandler, (connectionService) -> mockWriterFailoverHandler); - assertThrows(SQLException.class, () -> plugin.failoverWriter()); + assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); verify(mockWriterFailoverHandler).failover(eq(hosts)); } @@ -321,13 +321,13 @@ void test_failoverWriter_failedFailover_withNoResult() throws SQLException { when(mockWriterResult.isConnected()).thenReturn(false); initializePlugin(); - plugin.initHostProvider( + spyPlugin.initHostProvider( mockHostListProviderService, mockInitHostProviderFunc, (connectionService) -> mockReaderFailoverHandler, (connectionService) -> mockWriterFailoverHandler); - final SQLException exception = assertThrows(SQLException.class, () -> plugin.failoverWriter()); + final SQLException exception = assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); verify(mockWriterFailoverHandler).failover(eq(hosts)); @@ -340,13 +340,13 @@ void test_failoverWriter_successFailover() throws SQLException { when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); initializePlugin(); - plugin.initHostProvider( + spyPlugin.initHostProvider( mockHostListProviderService, mockInitHostProviderFunc, (connectionService) -> mockReaderFailoverHandler, (connectionService) -> mockWriterFailoverHandler); - final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> plugin.failoverWriter()); + final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverWriter()); assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); @@ -356,7 +356,7 @@ void test_failoverWriter_successFailover() throws SQLException { void test_invalidCurrentConnection_withNoConnection() { when(mockPluginService.getCurrentConnection()).thenReturn(null); initializePlugin(); - plugin.invalidateCurrentConnection(); + spyPlugin.invalidateCurrentConnection(); verify(mockPluginService, never()).getCurrentHostSpec(); } @@ -369,12 +369,12 @@ void test_invalidateCurrentConnection_inTransaction() throws SQLException { when(mockHostSpec.getRole()).thenReturn(HostRole.READER); initializePlugin(); - plugin.invalidateCurrentConnection(); + spyPlugin.invalidateCurrentConnection(); verify(mockConnection).rollback(); // Assert SQL exceptions thrown during rollback do not get propagated. doThrow(new SQLException()).when(mockConnection).rollback(); - assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); + assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); } @Test @@ -385,7 +385,7 @@ void test_invalidateCurrentConnection_notInTransaction() { when(mockHostSpec.getRole()).thenReturn(HostRole.READER); initializePlugin(); - plugin.invalidateCurrentConnection(); + spyPlugin.invalidateCurrentConnection(); verify(mockPluginService).isInTransaction(); } @@ -399,10 +399,10 @@ void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { when(mockHostSpec.getRole()).thenReturn(HostRole.READER); initializePlugin(); - plugin.invalidateCurrentConnection(); + spyPlugin.invalidateCurrentConnection(); doThrow(new SQLException()).when(mockConnection).close(); - assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); + assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); verify(mockConnection, times(2)).isClosed(); verify(mockConnection, times(2)).close(); @@ -413,7 +413,7 @@ void test_execute_withFailoverDisabled() throws SQLException { properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); initializePlugin(); - plugin.execute( + spyPlugin.execute( ResultSet.class, SQLException.class, MONITOR_METHOD_INVOKE_ON, @@ -428,7 +428,7 @@ void test_execute_withFailoverDisabled() throws SQLException { @Test void test_execute_withDirectExecute() throws SQLException { initializePlugin(); - plugin.execute( + spyPlugin.execute( ResultSet.class, SQLException.class, MONITOR_METHOD_INVOKE_ON, @@ -440,12 +440,12 @@ void test_execute_withDirectExecute() throws SQLException { } private void initializePlugin() { - plugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); - plugin.setWriterFailoverHandler(mockWriterFailoverHandler); - plugin.setReaderFailoverHandler(mockReaderFailoverHandler); + spyPlugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); + spyPlugin.setWriterFailoverHandler(mockWriterFailoverHandler); + spyPlugin.setReaderFailoverHandler(mockReaderFailoverHandler); try { - doReturn(mockConnectionService).when(plugin).getConnectionService(); + doReturn(mockConnectionService).when(spyPlugin).getConnectionService(); } catch (SQLException e) { fail( "Encountered exception when trying to stub FailoverConnectionPlugin#getConnectionService: " + e.getMessage()); From 7711a9e6c989e3e2b78b7ee19dda3dfccaa66856 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 22 Aug 2025 15:48:40 -0700 Subject: [PATCH 17/54] Fix failing MonitorServiceImplTest tests --- .../util/monitoring/MonitorServiceImpl.java | 43 ++++-- .../monitoring/MonitorServiceImplTest.java | 144 ++++++++++-------- 2 files changed, 107 insertions(+), 80 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index 227b29587..d45729c72 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -42,6 +42,7 @@ import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; +import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.connection.ConnectionServiceImpl; import software.amazon.jdbc.util.events.DataAccessEvent; import software.amazon.jdbc.util.events.Event; @@ -197,20 +198,15 @@ public T runIfAbsent( cacheContainer = monitorCaches.computeIfAbsent(monitorClass, k -> supplier.get()); } - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(originalUrl, originalProps); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); - final Properties propsCopy = PropertyUtils.copyProperties(originalProps); - final ConnectionServiceImpl connectionService = new ConnectionServiceImpl( - storageService, - this, - telemetryFactory, - defaultConnectionProvider, - originalUrl, - driverProtocol, - driverDialect, - dbDialect, - propsCopy); + final ConnectionService connectionService = + getConnectionService( + storageService, + telemetryFactory, + originalUrl, + driverProtocol, + driverDialect, + dbDialect, + originalProps); Monitor monitor = cacheContainer.getCache().computeIfAbsent(key, k -> { MonitorItem monitorItem = new MonitorItem(() -> initializer.createMonitor( @@ -228,6 +224,25 @@ public T runIfAbsent( Messages.get("MonitorServiceImpl.unexpectedMonitorClass", new Object[] {monitorClass, monitor})); } + protected ConnectionService getConnectionService(StorageService storageService, + TelemetryFactory telemetryFactory, String originalUrl, String driverProtocol, TargetDriverDialect driverDialect, + Dialect dbDialect, Properties originalProps) throws SQLException { + TargetDriverHelper helper = new TargetDriverHelper(); + java.sql.Driver driver = helper.getTargetDriver(originalUrl, originalProps); + final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + final Properties propsCopy = PropertyUtils.copyProperties(originalProps); + return new ConnectionServiceImpl( + storageService, + this, + telemetryFactory, + defaultConnectionProvider, + originalUrl, + driverProtocol, + driverDialect, + dbDialect, + propsCopy); + } + @Override public @Nullable T get(Class monitorClass, Object key) { CacheContainer cacheContainer = monitorCaches.get(monitorClass); diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index cd0bcbe3d..61248faaa 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -21,6 +21,11 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; import java.sql.SQLException; import java.util.Collections; @@ -28,6 +33,7 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -35,39 +41,45 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; class MonitorServiceImplTest { - @Mock StorageService storageService; - @Mock TelemetryFactory telemetryFactory; - @Mock TargetDriverDialect targetDriverDialect; - @Mock Dialect dbDialect; - @Mock EventPublisher publisher; - MonitorServiceImpl monitorService; + @Mock StorageService mockStorageService; + @Mock ConnectionService mockConnectionService; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock Dialect mockDbDialect; + @Mock EventPublisher mockPublisher; + MonitorServiceImpl spyMonitorService; private AutoCloseable closeable; @BeforeEach void setUp() { closeable = MockitoAnnotations.openMocks(this); - monitorService = new MonitorServiceImpl(publisher) { - @Override - protected void initCleanupThread(long cleanupIntervalNanos) { - // Do nothing - } - }; + spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); + doNothing().when(spyMonitorService).initCleanupThread(anyInt()); + + try { + doReturn(mockConnectionService).when(spyMonitorService) + .getConnectionService(any(), any(), any(), any(), any(), any(), any()); + } catch (SQLException e) { + Assertions.fail( + "Encountered exception while stubbing MonitorServiceImpl#getConnectionService: " + e.getMessage()); + } } @AfterEach void tearDown() throws Exception { closeable.close(); - monitorService.releaseResources(); + spyMonitorService.releaseResources(); } @Test public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { - monitorService.registerMonitorTypeIfAbsent( + spyMonitorService.registerMonitorTypeIfAbsent( NoOpMonitor.class, TimeUnit.MINUTES.toNanos(1), TimeUnit.MINUTES.toNanos(1), @@ -75,20 +87,20 @@ public void testMonitorError_monitorReCreated() throws SQLException, Interrupted null ); String key = "testMonitor"; - NoOpMonitor monitor = monitorService.runIfAbsent( + NoOpMonitor monitor = spyMonitorService.runIfAbsent( NoOpMonitor.class, key, - storageService, - telemetryFactory, + mockStorageService, + mockTelemetryFactory, "jdbc:postgresql://somehost/somedb", "someProtocol", - targetDriverDialect, - dbDialect, + mockTargetDriverDialect, + mockDbDialect, new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(monitorService, 30) + (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) ); - Monitor storedMonitor = monitorService.get(NoOpMonitor.class, key); + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); assertNotNull(storedMonitor); assertEquals(monitor, storedMonitor); // need to wait to give time for the monitor executor to start the monitor thread. @@ -96,11 +108,11 @@ public void testMonitorError_monitorReCreated() throws SQLException, Interrupted assertEquals(MonitorState.RUNNING, monitor.getState()); monitor.state.set(MonitorState.ERROR); - monitorService.checkMonitors(); + spyMonitorService.checkMonitors(); assertEquals(MonitorState.STOPPED, monitor.getState()); - Monitor newMonitor = monitorService.get(NoOpMonitor.class, key); + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); assertNotNull(newMonitor); assertNotEquals(monitor, newMonitor); // need to wait to give time for the monitor executor to start the monitor thread. @@ -110,7 +122,7 @@ public void testMonitorError_monitorReCreated() throws SQLException, Interrupted @Test public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { - monitorService.registerMonitorTypeIfAbsent( + spyMonitorService.registerMonitorTypeIfAbsent( NoOpMonitor.class, TimeUnit.MINUTES.toNanos(1), 1, // heartbeat times out immediately @@ -118,20 +130,20 @@ public void testMonitorStuck_monitorReCreated() throws SQLException, Interrupted null ); String key = "testMonitor"; - NoOpMonitor monitor = monitorService.runIfAbsent( + NoOpMonitor monitor = spyMonitorService.runIfAbsent( NoOpMonitor.class, key, - storageService, - telemetryFactory, + mockStorageService, + mockTelemetryFactory, "jdbc:postgresql://somehost/somedb", "someProtocol", - targetDriverDialect, - dbDialect, + mockTargetDriverDialect, + mockDbDialect, new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(monitorService, 30) + (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) ); - Monitor storedMonitor = monitorService.get(NoOpMonitor.class, key); + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); assertNotNull(storedMonitor); assertEquals(monitor, storedMonitor); // need to wait to give time for the monitor executor to start the monitor thread. @@ -139,11 +151,11 @@ public void testMonitorStuck_monitorReCreated() throws SQLException, Interrupted assertEquals(MonitorState.RUNNING, monitor.getState()); // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. - monitorService.checkMonitors(); + spyMonitorService.checkMonitors(); assertEquals(MonitorState.STOPPED, monitor.getState()); - Monitor newMonitor = monitorService.get(NoOpMonitor.class, key); + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); assertNotNull(newMonitor); assertNotEquals(monitor, newMonitor); // need to wait to give time for the monitor executor to start the monitor thread. @@ -153,7 +165,7 @@ public void testMonitorStuck_monitorReCreated() throws SQLException, Interrupted @Test public void testMonitorExpired() throws SQLException, InterruptedException { - monitorService.registerMonitorTypeIfAbsent( + spyMonitorService.registerMonitorTypeIfAbsent( NoOpMonitor.class, TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms TimeUnit.MINUTES.toNanos(1), @@ -163,20 +175,20 @@ public void testMonitorExpired() throws SQLException, InterruptedException { null ); String key = "testMonitor"; - NoOpMonitor monitor = monitorService.runIfAbsent( + NoOpMonitor monitor = spyMonitorService.runIfAbsent( NoOpMonitor.class, key, - storageService, - telemetryFactory, + mockStorageService, + mockTelemetryFactory, "jdbc:postgresql://somehost/somedb", "someProtocol", - targetDriverDialect, - dbDialect, + mockTargetDriverDialect, + mockDbDialect, new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(monitorService, 30) + (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) ); - Monitor storedMonitor = monitorService.get(NoOpMonitor.class, key); + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); assertNotNull(storedMonitor); assertEquals(monitor, storedMonitor); // need to wait to give time for the monitor executor to start the monitor thread. @@ -184,36 +196,36 @@ public void testMonitorExpired() throws SQLException, InterruptedException { assertEquals(MonitorState.RUNNING, monitor.getState()); // checkMonitors() should detect the expiration timeout and stop/remove the monitor. - monitorService.checkMonitors(); + spyMonitorService.checkMonitors(); assertEquals(MonitorState.STOPPED, monitor.getState()); - Monitor newMonitor = monitorService.get(NoOpMonitor.class, key); + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); // monitor should have been removed when checkMonitors() was called. assertNull(newMonitor); } @Test public void testMonitorMismatch() { - assertThrows(IllegalStateException.class, () -> monitorService.runIfAbsent( + assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( CustomEndpointMonitorImpl.class, "testMonitor", - storageService, - telemetryFactory, + mockStorageService, + mockTelemetryFactory, "jdbc:postgresql://somehost/somedb", "someProtocol", - targetDriverDialect, - dbDialect, + mockTargetDriverDialect, + mockDbDialect, new Properties(), // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor // service should detect this and throw an exception. - (connectionService, pluginService) -> new NoOpMonitor(monitorService, 30) + (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) )); } @Test public void testRemove() throws SQLException, InterruptedException { - monitorService.registerMonitorTypeIfAbsent( + spyMonitorService.registerMonitorTypeIfAbsent( NoOpMonitor.class, TimeUnit.MINUTES.toNanos(1), TimeUnit.MINUTES.toNanos(1), @@ -224,30 +236,30 @@ public void testRemove() throws SQLException, InterruptedException { ); String key = "testMonitor"; - NoOpMonitor monitor = monitorService.runIfAbsent( + NoOpMonitor monitor = spyMonitorService.runIfAbsent( NoOpMonitor.class, key, - storageService, - telemetryFactory, + mockStorageService, + mockTelemetryFactory, "jdbc:postgresql://somehost/somedb", "someProtocol", - targetDriverDialect, - dbDialect, + mockTargetDriverDialect, + mockDbDialect, new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(monitorService, 30) + (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) ); assertNotNull(monitor); // need to wait to give time for the monitor executor to start the monitor thread. TimeUnit.MILLISECONDS.sleep(250); - Monitor removedMonitor = monitorService.remove(NoOpMonitor.class, key); + Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); assertEquals(monitor, removedMonitor); assertEquals(MonitorState.RUNNING, monitor.getState()); } @Test public void testStopAndRemove() throws SQLException, InterruptedException { - monitorService.registerMonitorTypeIfAbsent( + spyMonitorService.registerMonitorTypeIfAbsent( NoOpMonitor.class, TimeUnit.MINUTES.toNanos(1), TimeUnit.MINUTES.toNanos(1), @@ -258,24 +270,24 @@ public void testStopAndRemove() throws SQLException, InterruptedException { ); String key = "testMonitor"; - NoOpMonitor monitor = monitorService.runIfAbsent( + NoOpMonitor monitor = spyMonitorService.runIfAbsent( NoOpMonitor.class, key, - storageService, - telemetryFactory, + mockStorageService, + mockTelemetryFactory, "jdbc:postgresql://somehost/somedb", "someProtocol", - targetDriverDialect, - dbDialect, + mockTargetDriverDialect, + mockDbDialect, new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(monitorService, 30) + (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) ); assertNotNull(monitor); // need to wait to give time for the monitor executor to start the monitor thread. TimeUnit.MILLISECONDS.sleep(250); - monitorService.stopAndRemove(NoOpMonitor.class, key); - assertNull(monitorService.get(NoOpMonitor.class, key)); + spyMonitorService.stopAndRemove(NoOpMonitor.class, key); + assertNull(spyMonitorService.get(NoOpMonitor.class, key)); assertEquals(MonitorState.STOPPED, monitor.getState()); } From 9928517af0e820b76db1923d31367d09a7fa5ead Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 26 Aug 2025 09:00:16 -0700 Subject: [PATCH 18/54] Fix failover1 integration test --- .../jdbc/plugin/failover/FailoverConnectionPlugin.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 8bbdbd5fd..19eb59a48 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -758,6 +758,14 @@ protected void failoverWriter() throws SQLException { this.failoverWriterTriggeredCounter.inc(); } + // The writer failover handler uses the reader failover handler, so we need to make sure it has been instantiated. + if (this.readerFailoverHandler == null) { + if (this.readerFailoverHandlerSupplier == null) { + throw new SQLException(Messages.get("Failover.nullReaderFailoverHandlerSupplier")); + } + this.readerFailoverHandler = this.readerFailoverHandlerSupplier.apply(this.connectionService); + } + if (this.writerFailoverHandler == null) { if (this.writerFailoverHandlerSupplier == null) { throw new SQLException(Messages.get("Failover.nullWriterFailoverHandlerSupplier")); From 6a338f2c51436028ccee113a748c0b4b57037093 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 26 Aug 2025 09:14:32 -0700 Subject: [PATCH 19/54] testServerFailoverWithIdleConnections uses the correct failover plugin --- .../src/test/java/integration/container/tests/FailoverTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wrapper/src/test/java/integration/container/tests/FailoverTest.java b/wrapper/src/test/java/integration/container/tests/FailoverTest.java index 1c9359fd5..347264561 100644 --- a/wrapper/src/test/java/integration/container/tests/FailoverTest.java +++ b/wrapper/src/test/java/integration/container/tests/FailoverTest.java @@ -328,7 +328,7 @@ public void testServerFailoverWithIdleConnections() throws SQLException, Interru TestEnvironment.getCurrent().getInfo().getProxyDatabaseInfo().getClusterEndpointPort(); final Properties props = initDefaultProxiedProps(); - props.setProperty(PropertyDefinition.PLUGINS.name, "auroraConnectionTracker,failover"); + props.setProperty(PropertyDefinition.PLUGINS.name, "auroraConnectionTracker," + getFailoverPlugin()); for (int i = 0; i < IDLE_CONNECTIONS_NUM; i++) { // Keep references to 5 idle connections created using the cluster endpoints. From faa09158d0bba9b6f833a16c38b32e9670b32512 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 26 Aug 2025 10:06:01 -0700 Subject: [PATCH 20/54] Fix javadocs --- .../ClusterAwareReaderFailoverHandler.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 8f3ab2f62..e64484959 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -76,8 +76,8 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler /** * ClusterAwareReaderFailoverHandler constructor. * - * @param servicesContainer A provider for creating new connections. - * @param props The initial connection properties to copy over to the new reader. + * @param servicesContainer the service container for the services required by this class. + * @param props the initial connection properties to copy over to the new reader. */ public ClusterAwareReaderFailoverHandler( final FullServicesContainer servicesContainer, @@ -95,12 +95,12 @@ public ClusterAwareReaderFailoverHandler( /** * ClusterAwareReaderFailoverHandler constructor. * - * @param servicesContainer A provider for creating new connections. - * @param props The initial connection properties to copy over to the new reader. - * @param maxFailoverTimeoutMs Maximum allowed time for the entire reader failover process. - * @param timeoutMs Maximum allowed time in milliseconds for each reader connection attempt during - * the reader failover process. - * @param isStrictReaderRequired When true, it disables adding a writer to a list of nodes to connect + * @param servicesContainer the service container for the services required by this class. + * @param props the initial connection properties to copy over to the new reader. + * @param maxFailoverTimeoutMs maximum allowed time for the entire reader failover process. + * @param timeoutMs maximum allowed time in milliseconds for each reader connection attempt during + * the reader failover process. + * @param isStrictReaderRequired when true, it disables adding a writer to a list of nodes to connect */ public ClusterAwareReaderFailoverHandler( final FullServicesContainer servicesContainer, From 23e177af7614e4ab38ef8f646d2f0b6b9f2f19bb Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 26 Aug 2025 09:59:14 -0700 Subject: [PATCH 21/54] wip --- .../jdbc/util/ServiceContainerUtility.java | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java new file mode 100644 index 000000000..fb4bc4d70 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util; + +import java.util.concurrent.locks.ReentrantLock; + +public class ServiceContainerUtility { + private static volatile ServiceContainerUtility instance; + private static final ReentrantLock initLock = new ReentrantLock(); + + private ServiceContainerUtility() { + if (instance != null) { + throw new IllegalStateException("ServiceContainerUtility singleton instance already exists."); + } + } + + public static ServiceContainerUtility getInstance() { + if (instance != null) { + return instance; + } + + initLock.lock(); + try { + if (instance == null) { + instance = new ServiceContainerUtility(); + } + } finally { + initLock.unlock(); + } + + return instance; + } + + public static FullServicesContainer createServiceContainer() { + + } +} From e1af3eb2757e3e0223c8863cbb368e093dc8018d Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Thu, 28 Aug 2025 18:12:54 -0700 Subject: [PATCH 22/54] Add extra logging to debug IT failure --- .../plugin/readwritesplitting/ReadWriteSplittingPlugin.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 7a2955843..35d768211 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -422,11 +422,14 @@ private void switchToReaderConnection(final List hosts) if (this.readerHostSpec != null && !hosts.contains(this.readerHostSpec)) { // The old reader cannot be used anymore because it is no longer in the list of allowed hosts. + LOGGER.finest(String.format("ASDF old reader cannot be used, reader host spec: %s", this.readerHostSpec)); + LOGGER.finest(Utils.logTopology(hosts, "ASDF: ")); closeConnectionIfIdle(this.readerConnection); } this.inReadWriteSplit = true; if (!isConnectionUsable(this.readerConnection)) { + LOGGER.finest(String.format("ASDF reader connection not usable, value: %s, isClosed: %s", this.readerConnection, this.readerConnection == null ? "" : this.readerConnection.isClosed())); initializeReaderConnection(hosts); } else { try { From c077ee218977cfbb97f432e973949fd2b1f5937d Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 29 Aug 2025 09:19:45 -0700 Subject: [PATCH 23/54] Fix bug where cached reader connection was incorrectly closed --- .../readwritesplitting/ReadWriteSplittingPlugin.java | 11 ++++++----- .../aws_advanced_jdbc_wrapper_messages.properties | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 35d768211..90b170f98 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -420,16 +420,17 @@ private void switchToReaderConnection(final List hosts) return; } - if (this.readerHostSpec != null && !hosts.contains(this.readerHostSpec)) { - // The old reader cannot be used anymore because it is no longer in the list of allowed hosts. - LOGGER.finest(String.format("ASDF old reader cannot be used, reader host spec: %s", this.readerHostSpec)); - LOGGER.finest(Utils.logTopology(hosts, "ASDF: ")); + if (this.readerHostSpec != null && !Utils.containsUrl(hosts, this.readerHostSpec.getUrl())) { + // The previous reader cannot be used anymore because it is no longer in the list of allowed hosts. + LOGGER.finest( + Messages.get( + "ReadWriteSplittingPlugin.previousReaderNotAllowed", + new Object[] {this.readerHostSpec, Utils.logTopology(hosts, "")})); closeConnectionIfIdle(this.readerConnection); } this.inReadWriteSplit = true; if (!isConnectionUsable(this.readerConnection)) { - LOGGER.finest(String.format("ASDF reader connection not usable, value: %s, isClosed: %s", this.readerConnection, this.readerConnection == null ? "" : this.readerConnection.isClosed())); initializeReaderConnection(hosts); } else { try { diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index 6bb12413d..1776fcced 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -316,6 +316,7 @@ ReadWriteSplittingPlugin.successfullyConnectedToReader=Successfully connected to ReadWriteSplittingPlugin.failedToConnectToReader=Failed to connect to reader host: ''{0}'' ReadWriteSplittingPlugin.unsupportedHostSpecSelectorStrategy=Unsupported host selection strategy ''{0}'' specified in plugin configuration parameter ''readerHostSelectorStrategy''. Please visit the Read/Write Splitting Plugin documentation for all supported strategies. ReadWriteSplittingPlugin.errorVerifyingInitialHostSpecRole=An error occurred while obtaining the connected host's role. This could occur if the connection is broken or if you are not connected to an Aurora database. +ReadWriteSplittingPlugin.previousReaderNotAllowed=The previous reader connection cannot be used because it is no longer in the list of allowed hosts. Previous reader: {0}. Allowed hosts: {1} SAMLCredentialsProviderFactory.getSamlAssertionFailed=Failed to get SAML Assertion due to exception: ''{0}'' SamlAuthPlugin.javaStsSdkNotInClasspath=Required dependency 'AWS Java SDK for AWS Secret Token Service' is not on the classpath. From 73cf143cf1da8f55b77421472d731512eb8b218b Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 29 Aug 2025 13:43:38 -0700 Subject: [PATCH 24/54] wip --- .../ClusterAwareReaderFailoverHandler.java | 2 - .../failover/FailoverConnectionPlugin.java | 6 +-- .../jdbc/util/ServiceContainerUtility.java | 47 ++++++++++++++++++- 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 8f3ab2f62..0304252d3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -81,11 +81,9 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler */ public ClusterAwareReaderFailoverHandler( final FullServicesContainer servicesContainer, - final ConnectionService connectionService, final Properties props) { this( servicesContainer, - connectionService, props, DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 19eb59a48..b0acf9de0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -316,18 +316,16 @@ public void initHostProvider( initHostProvider( hostListProviderService, initHostProviderFunc, - (connectionService) -> + () -> new ClusterAwareReaderFailoverHandler( this.servicesContainer, - connectionService, this.properties, this.failoverTimeoutMsSetting, this.failoverReaderConnectTimeoutMsSetting, this.failoverMode == FailoverMode.STRICT_READER), - (connectionService) -> + () -> new ClusterAwareWriterFailoverHandler( this.servicesContainer, - connectionService, this.readerFailoverHandler, this.properties, this.failoverTimeoutMsSetting, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java index fb4bc4d70..e335b5c5f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java @@ -16,7 +16,17 @@ package software.amazon.jdbc.util; +import java.sql.SQLException; +import java.util.Properties; import java.util.concurrent.locks.ReentrantLock; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.PartialPluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.monitoring.MonitorService; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class ServiceContainerUtility { private static volatile ServiceContainerUtility instance; @@ -45,7 +55,42 @@ public static ServiceContainerUtility getInstance() { return instance; } - public static FullServicesContainer createServiceContainer() { + public static FullServicesContainer createServiceContainer( + StorageService storageService, + MonitorService monitorService, + TelemetryFactory telemetryFactory, + ConnectionProvider connectionProvider, + String originalUrl, + String targetDriverProtocol, + TargetDriverDialect driverDialect, + Dialect dbDialect, + Properties props) throws SQLException { + FullServicesContainer + servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); + ConnectionPluginManager pluginManager = new ConnectionPluginManager( + connectionProvider, + null, + null, + telemetryFactory); + servicesContainer.setConnectionPluginManager(pluginManager); + PartialPluginService partialPluginService = new PartialPluginService( + servicesContainer, + props, + originalUrl, + targetDriverProtocol, + driverDialect, + dbDialect + ); + + pluginManager.init(servicesContainer, props, partialPluginService, null); + return new FullServicesContainerImpl( + storageService, + monitorService, + telemetryFactory, + pluginManager, + partialPluginService, + partialPluginService, + partialPluginService); } } From 99de6eb235a66040b1ff4dba2314ac1f85feedb1 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 29 Aug 2025 13:49:54 -0700 Subject: [PATCH 25/54] Fix javadoc --- .../plugin/failover/ClusterAwareReaderFailoverHandler.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index e64484959..35f16b35f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -77,6 +77,7 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler * ClusterAwareReaderFailoverHandler constructor. * * @param servicesContainer the service container for the services required by this class. + * @param connectionService the service to use to create new connections during failover. * @param props the initial connection properties to copy over to the new reader. */ public ClusterAwareReaderFailoverHandler( @@ -96,6 +97,7 @@ public ClusterAwareReaderFailoverHandler( * ClusterAwareReaderFailoverHandler constructor. * * @param servicesContainer the service container for the services required by this class. + * @param connectionService the service to use to create new connections during failover. * @param props the initial connection properties to copy over to the new reader. * @param maxFailoverTimeoutMs maximum allowed time for the entire reader failover process. * @param timeoutMs maximum allowed time in milliseconds for each reader connection attempt during @@ -292,7 +294,7 @@ public List getReaderHostsByPriority(final List hosts) { boolean shouldIncludeWriter = numOfReaders == 0 || this.pluginService.getDialect().getFailoverRestrictions() - .contains(FailoverRestriction.ENABLE_WRITER_IN_TASK_B); + .contains(FailoverRestriction.ENABLE_WRITER_IN_TASK_B); if (shouldIncludeWriter) { hostsByPriority.add(writerHost); } From 9311447abce720fa3bc8c11f9a1c053d87ff88e5 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 29 Aug 2025 16:07:22 -0700 Subject: [PATCH 26/54] wip --- .../ClusterTopologyMonitorImpl.java | 47 +- .../MonitoringRdsHostListProvider.java | 9 +- .../ClusterAwareReaderFailoverHandler.java | 58 +- .../ClusterAwareWriterFailoverHandler.java | 55 +- .../failover/FailoverConnectionPlugin.java | 43 +- .../jdbc/util/ServiceContainerUtility.java | 9 +- .../util/monitoring/MonitorInitializer.java | 7 +- .../util/monitoring/MonitorServiceImpl.java | 43 +- ...ClusterAwareReaderFailoverHandlerTest.java | 802 +++++++++--------- ...ClusterAwareWriterFailoverHandlerTest.java | 746 ++++++++-------- .../monitoring/MonitorServiceImplTest.java | 612 ++++++------- 11 files changed, 1193 insertions(+), 1238 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index 4e7b62de1..e61cb6a60 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -49,14 +49,17 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.util.ExecutorFactory; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.ServiceContainerUtility; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SynchronousExecutor; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.AbstractMonitor; +import software.amazon.jdbc.util.monitoring.Monitor; import software.amazon.jdbc.util.storage.StorageService; public class ClusterTopologyMonitorImpl extends AbstractMonitor implements ClusterTopologyMonitor { @@ -80,15 +83,13 @@ public class ClusterTopologyMonitorImpl extends AbstractMonitor implements Clust protected final long refreshRateNano; protected final long highRefreshRateNano; + protected final FullServicesContainer servicesContainer; protected final Properties properties; protected final Properties monitoringProperties; protected final HostSpec initialHostSpec; - protected final StorageService storageService; - protected final ConnectionService connectionService; protected final String topologyQuery; protected final String nodeIdQuery; protected final String writerTopologyQuery; - protected final HostListProviderService hostListProviderService; protected final HostSpec clusterInstanceTemplate; protected String clusterId; @@ -109,12 +110,10 @@ public class ClusterTopologyMonitorImpl extends AbstractMonitor implements Clust protected final AtomicReference> nodeThreadsLatestTopology = new AtomicReference<>(null); public ClusterTopologyMonitorImpl( + final FullServicesContainer servicesContainer, final String clusterId, - final StorageService storageService, - final ConnectionService connectionService, final HostSpec initialHostSpec, final Properties properties, - final HostListProviderService hostListProviderService, final HostSpec clusterInstanceTemplate, final long refreshRateNano, final long highRefreshRateNano, @@ -124,9 +123,7 @@ public ClusterTopologyMonitorImpl( super(monitorTerminationTimeoutSec); this.clusterId = clusterId; - this.storageService = storageService; - this.connectionService = connectionService; - this.hostListProviderService = hostListProviderService; + this.servicesContainer = servicesContainer; this.initialHostSpec = initialHostSpec; this.clusterInstanceTemplate = clusterInstanceTemplate; this.properties = properties; @@ -251,7 +248,7 @@ protected List waitTillTopologyGetsUpdated(final long timeoutMs) throw } private List getStoredHosts() { - Topology topology = storageService.get(Topology.class, this.clusterId); + Topology topology = this.servicesContainer.getStorageService().get(Topology.class, this.clusterId); return topology == null ? null : topology.getHosts(); } @@ -480,8 +477,23 @@ protected boolean isInPanicMode() { || !this.isVerifiedWriterConnection; } - protected Runnable getNodeMonitoringWorker(final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec) { - return new NodeMonitoringWorker(this, hostSpec, writerHostSpec); + protected Runnable getNodeMonitoringWorker( + final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec) + throws SQLException { + return new NodeMonitoringWorker(this.getNewServicesContainer(), this, hostSpec, writerHostSpec); + } + + protected FullServicesContainer getNewServicesContainer() throws SQLException { + return ServiceContainerUtility.createServiceContainer( + this.servicesContainer.getStorageService(), + this.servicesContainer.getMonitorService(), + this.servicesContainer.getTelemetryFactory(), + this.servicesContainer.getPluginService().getOriginalUrl(), + this.servicesContainer.getPluginService().getDriverProtocol(), + this.servicesContainer.getPluginService().getTargetDriverDialect(), + this.servicesContainer.getPluginService().getDialect(), + this.properties + ); } protected List openAnyConnectionAndUpdateTopology() { @@ -492,7 +504,7 @@ protected List openAnyConnectionAndUpdateTopology() { // open a new connection try { - conn = this.connectionService.open(this.initialHostSpec, this.monitoringProperties); + conn = this.servicesContainer.getPluginService().forceConnect(this.initialHostSpec, this.monitoringProperties); } catch (SQLException ex) { // can't connect return null; @@ -625,7 +637,7 @@ protected void delay(boolean useHighRefreshRate) throws InterruptedException { protected void updateTopologyCache(final @NonNull List hosts) { synchronized (this.requestToUpdateTopology) { - storageService.set(this.clusterId, new Topology(hosts)); + this.servicesContainer.getStorageService().set(this.clusterId, new Topology(hosts)); synchronized (this.topologyUpdated) { this.requestToUpdateTopology.set(false); @@ -769,7 +781,7 @@ protected HostSpec createHost( ? this.clusterInstanceTemplate.getPort() : this.initialHostSpec.getPort(); - final HostSpec hostSpec = this.hostListProviderService.getHostSpecBuilder() + final HostSpec hostSpec = this.servicesContainer.getHostListProviderService().getHostSpecBuilder() .host(endpoint) .port(port) .role(isWriter ? HostRole.WRITER : HostRole.READER) @@ -791,16 +803,19 @@ private static class NodeMonitoringWorker implements Runnable { private static final Logger LOGGER = Logger.getLogger(NodeMonitoringWorker.class.getName()); + protected final FullServicesContainer servicesContainer; protected final ClusterTopologyMonitorImpl monitor; protected final HostSpec hostSpec; protected final @Nullable HostSpec writerHostSpec; protected boolean writerChanged = false; public NodeMonitoringWorker( + final FullServicesContainer servicesContainer, final ClusterTopologyMonitorImpl monitor, final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec ) { + this.servicesContainer = servicesContainer; this.monitor = monitor; this.hostSpec = hostSpec; this.writerHostSpec = writerHostSpec; @@ -818,7 +833,7 @@ public void run() { if (connection == null) { try { - connection = this.monitor.connectionService.open( + connection = this.servicesContainer.getPluginService().forceConnect( hostSpec, this.monitor.monitoringProperties); } catch (SQLException ex) { // A problem occurred while connecting. We will try again on the next iteration. diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java index 29aae6ddd..2883100e0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java @@ -32,7 +32,6 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; @@ -92,13 +91,11 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), this.properties, - (ConnectionService connectionService, PluginService monitorPluginService) -> new ClusterTopologyMonitorImpl( + (servicesContainer) -> new ClusterTopologyMonitorImpl( + this.servicesContainer, this.clusterId, - this.servicesContainer.getStorageService(), - connectionService, this.initialHostSpec, this.properties, - this.servicesContainer.getHostListProviderService(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, @@ -138,7 +135,7 @@ protected void clusterIdChanged(final String oldClusterId) throws SQLException { this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), this.properties, - (connectionService, pluginService) -> existingMonitor); + (servicesContainer) -> existingMonitor); assert monitorService.get(ClusterTopologyMonitorImpl.class, this.clusterId) == existingMonitor; existingMonitor.setClusterId(this.clusterId); monitorService.remove(ClusterTopologyMonitorImpl.class, oldClusterId); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 6045e00ba..1ad010af8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -36,15 +36,14 @@ import java.util.logging.Logger; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; +import software.amazon.jdbc.util.ServiceContainerUtility; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionService; /** * An implementation of ReaderFailoverHandler. @@ -66,7 +65,6 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler protected static final int DEFAULT_READER_CONNECT_TIMEOUT = 30000; // 30 sec protected final Map hostAvailabilityMap = new ConcurrentHashMap<>(); protected final FullServicesContainer servicesContainer; - protected final ConnectionService connectionService; protected final PluginService pluginService; protected Properties props; protected int maxFailoverTimeoutMs; @@ -77,7 +75,6 @@ public class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler * ClusterAwareReaderFailoverHandler constructor. * * @param servicesContainer the service container for the services required by this class. - * @param connectionService the service to use to create new connections during failover. * @param props the initial connection properties to copy over to the new reader. */ public ClusterAwareReaderFailoverHandler( @@ -95,7 +92,6 @@ public ClusterAwareReaderFailoverHandler( * ClusterAwareReaderFailoverHandler constructor. * * @param servicesContainer the service container for the services required by this class. - * @param connectionService the service to use to create new connections during failover. * @param props the initial connection properties to copy over to the new reader. * @param maxFailoverTimeoutMs maximum allowed time for the entire reader failover process. * @param timeoutMs maximum allowed time in milliseconds for each reader connection attempt during @@ -104,13 +100,11 @@ public ClusterAwareReaderFailoverHandler( */ public ClusterAwareReaderFailoverHandler( final FullServicesContainer servicesContainer, - final ConnectionService connectionService, final Properties props, final int maxFailoverTimeoutMs, final int timeoutMs, final boolean isStrictReaderRequired) { this.servicesContainer = servicesContainer; - this.connectionService = connectionService; this.pluginService = servicesContainer.getPluginService(); this.props = props; this.maxFailoverTimeoutMs = maxFailoverTimeoutMs; @@ -304,15 +298,16 @@ private ReaderFailoverResult getConnectionFromHostGroup(final List hos final ExecutorService executor = ExecutorFactory.newFixedThreadPool(2, "failover"); final CompletionService completionService = new ExecutorCompletionService<>(executor); - // The ConnectionAttemptTask threads should have their own plugin services since they execute concurrently and - // PluginService was not designed to be thread-safe. - List pluginServices = Arrays.asList(getNewPluginService(), getNewPluginService()); + // The ConnectionAttemptTask threads should each have their own services container since they execute concurrently + // and PluginService was not designed to be thread-safe. + List servicesContainers = + Arrays.asList(getNewServicesContainer(), getNewServicesContainer()); try { for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 final ReaderFailoverResult result = - getResultFromNextTaskBatch(hosts, executor, completionService, pluginServices, i); + getResultFromNextTaskBatch(hosts, executor, completionService, servicesContainers, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -335,14 +330,13 @@ private ReaderFailoverResult getResultFromNextTaskBatch( final List hosts, final ExecutorService executor, final CompletionService completionService, - final List pluginServices, + final List servicesContainers, final int i) throws SQLException { ReaderFailoverResult result; final int numTasks = i + 1 < hosts.size() ? 2 : 1; completionService.submit( new ConnectionAttemptTask( - this.connectionService, - pluginServices.get(0), + servicesContainers.get(0), this.hostAvailabilityMap, hosts.get(i), this.props, @@ -350,8 +344,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( if (numTasks == 2) { completionService.submit( new ConnectionAttemptTask( - this.connectionService, - pluginServices.get(1), + servicesContainers.get(1), this.hostAvailabilityMap, hosts.get(i + 1), this.props, @@ -372,6 +365,19 @@ private ReaderFailoverResult getResultFromNextTaskBatch( return new ReaderFailoverResult(null, null, false); } + protected FullServicesContainer getNewServicesContainer() throws SQLException { + return ServiceContainerUtility.createServiceContainer( + this.servicesContainer.getStorageService(), + this.servicesContainer.getMonitorService(), + this.servicesContainer.getTelemetryFactory(), + this.pluginService.getOriginalUrl(), + this.pluginService.getDriverProtocol(), + this.pluginService.getTargetDriverDialect(), + this.pluginService.getDialect(), + this.props + ); + } + private ReaderFailoverResult getNextResult(final CompletionService service) throws SQLException { try { @@ -394,19 +400,7 @@ private ReaderFailoverResult getNextResult(final CompletionService { - private final ConnectionService connectionService; private final PluginService pluginService; private final Map availabilityMap; private final HostSpec newHost; @@ -414,14 +408,12 @@ private static class ConnectionAttemptTask implements Callable availabilityMap, final HostSpec newHost, final Properties props, final boolean isStrictReaderRequired) { - this.connectionService = connectionService; - this.pluginService = pluginService; + this.pluginService = servicesContainer.getPluginService(); this.availabilityMap = availabilityMap; this.newHost = newHost; this.props = props; @@ -442,7 +434,7 @@ public ReaderFailoverResult call() { final Properties copy = new Properties(); copy.putAll(props); - final Connection conn = this.connectionService.open(this.newHost, copy); + final Connection conn = this.pluginService.forceConnect(this.newHost, copy); this.availabilityMap.put(this.newHost.getHost(), HostAvailability.AVAILABLE); if (this.isStrictReaderRequired) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 1921f3238..28c614141 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -34,15 +34,14 @@ import java.util.logging.Logger; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; +import software.amazon.jdbc.util.ServiceContainerUtility; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionService; /** * An implementation of WriterFailoverHandler. @@ -59,7 +58,6 @@ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler protected final Properties initialConnectionProps; protected final FullServicesContainer servicesContainer; - protected final ConnectionService connectionService; protected final PluginService pluginService; protected final ReaderFailoverHandler readerFailoverHandler; protected final Map hostAvailabilityMap = new ConcurrentHashMap<>(); @@ -69,11 +67,9 @@ public class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, - final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps) { this.servicesContainer = servicesContainer; - this.connectionService = connectionService; this.pluginService = servicesContainer.getPluginService(); this.readerFailoverHandler = readerFailoverHandler; this.initialConnectionProps = initialConnectionProps; @@ -81,7 +77,6 @@ public ClusterAwareWriterFailoverHandler( public ClusterAwareWriterFailoverHandler( final FullServicesContainer servicesContainer, - final ConnectionService connectionService, final ReaderFailoverHandler readerFailoverHandler, final Properties initialConnectionProps, final int failoverTimeoutMs, @@ -89,7 +84,6 @@ public ClusterAwareWriterFailoverHandler( final int reconnectWriterIntervalMs) { this( servicesContainer, - connectionService, readerFailoverHandler, initialConnectionProps); this.maxFailoverTimeoutMs = failoverTimeoutMs; @@ -103,8 +97,7 @@ public Map getHostAvailabilityMap() { } @Override - public WriterFailoverResult failover(final List currentTopology) - throws SQLException { + public WriterFailoverResult failover(final List currentTopology) throws SQLException { if (Utils.isNullOrEmpty(currentTopology)) { LOGGER.severe(() -> Messages.get("ClusterAwareWriterFailoverHandler.failoverCalledWithInvalidTopology")); return DEFAULT_RESULT; @@ -162,13 +155,12 @@ private void submitTasks( final List currentTopology, final ExecutorService executorService, final CompletionService completionService, - final boolean singleTask) { + final boolean singleTask) throws SQLException { final HostSpec writerHost = getWriter(currentTopology); if (!singleTask) { completionService.submit( new ReconnectToWriterHandler( - this.connectionService, - this.getNewPluginService(), + this.getNewServicesContainer(), this.hostAvailabilityMap, writerHost, this.initialConnectionProps, @@ -177,8 +169,7 @@ private void submitTasks( completionService.submit( new WaitForNewWriterHandler( - this.connectionService, - this.getNewPluginService(), + this.getNewServicesContainer(), this.hostAvailabilityMap, this.readerFailoverHandler, writerHost, @@ -189,16 +180,18 @@ private void submitTasks( executorService.shutdown(); } - protected PluginService getNewPluginService() { - // Each task should get its own PluginService since they execute concurrently and PluginService was not designed to - // be thread-safe. - return new PartialPluginService( - this.servicesContainer, - this.initialConnectionProps, + protected FullServicesContainer getNewServicesContainer() throws SQLException { + // Each task should get its own FullServicesContainer since they execute concurrently and PluginService was not + // designed to be thread-safe. + return ServiceContainerUtility.createServiceContainer( + this.servicesContainer.getStorageService(), + this.servicesContainer.getMonitorService(), + this.servicesContainer.getTelemetryFactory(), this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect() + this.pluginService.getDialect(), + this.initialConnectionProps ); } @@ -275,8 +268,6 @@ private SQLException createInterruptedException(final InterruptedException e) { * Internal class responsible for re-connecting to the current writer (aka TaskA). */ private static class ReconnectToWriterHandler implements Callable { - - private final ConnectionService connectionService; private final PluginService pluginService; private final Map availabilityMap; private final HostSpec originalWriterHost; @@ -284,14 +275,12 @@ private static class ReconnectToWriterHandler implements Callable availabilityMap, final HostSpec originalWriterHost, final Properties props, final int reconnectWriterIntervalMs) { - this.connectionService = connectionService; - this.pluginService = pluginService; + this.pluginService = servicesContainer.getPluginService(); this.availabilityMap = availabilityMap; this.originalWriterHost = originalWriterHost; this.props = props; @@ -315,7 +304,7 @@ public WriterFailoverResult call() { conn.close(); } - conn = this.connectionService.open(this.originalWriterHost, this.props); + conn = this.pluginService.forceConnect(this.originalWriterHost, this.props); this.pluginService.forceRefreshHostList(conn); latestTopology = this.pluginService.getAllHosts(); } catch (final SQLException exception) { @@ -371,8 +360,6 @@ private boolean isCurrentHostWriter(final List latestTopology) { * elected writer (aka TaskB). */ private static class WaitForNewWriterHandler implements Callable { - - private final ConnectionService connectionService; private final PluginService pluginService; private final Map availabilityMap; private final ReaderFailoverHandler readerFailoverHandler; @@ -385,16 +372,14 @@ private static class WaitForNewWriterHandler implements Callable availabilityMap, final ReaderFailoverHandler readerFailoverHandler, final HostSpec originalWriterHost, final Properties props, final int readTopologyIntervalMs, final List currentTopology) { - this.connectionService = connectionService; - this.pluginService = pluginService; + this.pluginService = servicesContainer.getPluginService(); this.availabilityMap = availabilityMap; this.readerFailoverHandler = readerFailoverHandler; this.originalWriterHost = originalWriterHost; @@ -537,7 +522,7 @@ private boolean connectToWriter(final HostSpec writerCandidate) { new Object[] {writerCandidate.getUrl()})); try { // connect to the new writer - this.currentConnection = this.connectionService.open(writerCandidate, this.props); + this.currentConnection = this.pluginService.forceConnect(writerCandidate, this.props); this.availabilityMap.put(writerCandidate.getHost(), HostAvailability.AVAILABLE); return true; } catch (final SQLException exception) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index b0acf9de0..66cdcfb12 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -28,13 +28,11 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; +import java.util.function.Supplier; import java.util.logging.Level; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.AwsWrapperProperty; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.DriverConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; @@ -44,7 +42,6 @@ import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.TargetDriverHelper; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsHelper; @@ -56,8 +53,6 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionService; -import software.amazon.jdbc.util.connection.ConnectionServiceImpl; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -99,7 +94,6 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { private final Set subscribedMethods; private final PluginService pluginService; private final FullServicesContainer servicesContainer; - private ConnectionService connectionService; protected final Properties properties; protected boolean enableFailoverSetting; protected boolean enableConnectFailover; @@ -122,8 +116,8 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { private RdsUrlType rdsUrlType = null; private HostListProviderService hostListProviderService; private final AuroraStaleDnsHelper staleDnsHelper; - private Function writerFailoverHandlerSupplier; - private Function readerFailoverHandlerSupplier; + private Supplier writerFailoverHandlerSupplier; + private Supplier readerFailoverHandlerSupplier; public static final AwsWrapperProperty FAILOVER_CLUSTER_TOPOLOGY_REFRESH_RATE_MS = new AwsWrapperProperty( @@ -336,8 +330,8 @@ public void initHostProvider( void initHostProvider( final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc, - final Function readerFailoverHandlerSupplier, - final Function writerFailoverHandlerSupplier) + final Supplier readerFailoverHandlerSupplier, + final Supplier writerFailoverHandlerSupplier) throws SQLException { this.readerFailoverHandlerSupplier = readerFailoverHandlerSupplier; this.writerFailoverHandlerSupplier = writerFailoverHandlerSupplier; @@ -618,10 +612,6 @@ protected void dealWithIllegalStateException( */ protected void failover(final HostSpec failedHost) throws SQLException { this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); - if (this.connectionService == null) { - this.connectionService = getConnectionService(); - } - if (this.failoverMode == FailoverMode.STRICT_WRITER) { failoverWriter(); } else { @@ -629,23 +619,6 @@ protected void failover(final HostSpec failedHost) throws SQLException { } } - protected ConnectionService getConnectionService() throws SQLException { - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(this.pluginService.getOriginalUrl(), properties); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); - return new ConnectionServiceImpl( - servicesContainer.getStorageService(), - servicesContainer.getMonitorService(), - servicesContainer.getTelemetryFactory(), - defaultConnectionProvider, - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - properties - ); - } - protected void failoverReader(final HostSpec failedHostSpec) throws SQLException { TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext( @@ -658,7 +631,7 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException if (this.readerFailoverHandlerSupplier == null) { throw new SQLException(Messages.get("Failover.nullReaderFailoverHandlerSupplier")); } - this.readerFailoverHandler = this.readerFailoverHandlerSupplier.apply(this.connectionService); + this.readerFailoverHandler = this.readerFailoverHandlerSupplier.get(); } final long failoverStartNano = System.nanoTime(); @@ -761,14 +734,14 @@ protected void failoverWriter() throws SQLException { if (this.readerFailoverHandlerSupplier == null) { throw new SQLException(Messages.get("Failover.nullReaderFailoverHandlerSupplier")); } - this.readerFailoverHandler = this.readerFailoverHandlerSupplier.apply(this.connectionService); + this.readerFailoverHandler = this.readerFailoverHandlerSupplier.get(); } if (this.writerFailoverHandler == null) { if (this.writerFailoverHandlerSupplier == null) { throw new SQLException(Messages.get("Failover.nullWriterFailoverHandlerSupplier")); } - this.writerFailoverHandler = this.writerFailoverHandlerSupplier.apply(this.connectionService); + this.writerFailoverHandler = this.writerFailoverHandlerSupplier.get(); } long failoverStartTimeNano = System.nanoTime(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java index e335b5c5f..e0e57b793 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java @@ -21,7 +21,9 @@ import java.util.concurrent.locks.ReentrantLock; import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.DriverConnectionProvider; import software.amazon.jdbc.PartialPluginService; +import software.amazon.jdbc.TargetDriverHelper; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.monitoring.MonitorService; @@ -59,16 +61,19 @@ public static FullServicesContainer createServiceContainer( StorageService storageService, MonitorService monitorService, TelemetryFactory telemetryFactory, - ConnectionProvider connectionProvider, String originalUrl, String targetDriverProtocol, TargetDriverDialect driverDialect, Dialect dbDialect, Properties props) throws SQLException { + final TargetDriverHelper helper = new TargetDriverHelper(); + final java.sql.Driver driver = helper.getTargetDriver(originalUrl, props); + final ConnectionProvider connProvider = new DriverConnectionProvider(driver); + FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); ConnectionPluginManager pluginManager = new ConnectionPluginManager( - connectionProvider, + connProvider, null, null, telemetryFactory); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorInitializer.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorInitializer.java index c4f13e6f9..e1cb59a95 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorInitializer.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorInitializer.java @@ -16,10 +16,9 @@ package software.amazon.jdbc.util.monitoring; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.util.connection.ConnectionService; +import software.amazon.jdbc.util.FullServicesContainer; +@FunctionalInterface public interface MonitorInitializer { - - Monitor createMonitor(ConnectionService connectionService, PluginService pluginService); + Monitor createMonitor(FullServicesContainer servicesContainer); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index d45729c72..b7d59f329 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -30,9 +30,6 @@ import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.DriverConnectionProvider; -import software.amazon.jdbc.TargetDriverHelper; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.hostlistprovider.monitoring.ClusterTopologyMonitorImpl; @@ -40,10 +37,10 @@ import software.amazon.jdbc.plugin.strategy.fastestresponse.NodeResponseTimeMonitor; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.ExecutorFactory; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.connection.ConnectionService; -import software.amazon.jdbc.util.connection.ConnectionServiceImpl; +import software.amazon.jdbc.util.ServiceContainerUtility; import software.amazon.jdbc.util.events.DataAccessEvent; import software.amazon.jdbc.util.events.Event; import software.amazon.jdbc.util.events.EventPublisher; @@ -198,20 +195,10 @@ public T runIfAbsent( cacheContainer = monitorCaches.computeIfAbsent(monitorClass, k -> supplier.get()); } - final ConnectionService connectionService = - getConnectionService( - storageService, - telemetryFactory, - originalUrl, - driverProtocol, - driverDialect, - dbDialect, - originalProps); - + final FullServicesContainer servicesContainer = getNewServicesContainer( + storageService, telemetryFactory, originalUrl, driverProtocol, driverDialect, dbDialect, originalProps); Monitor monitor = cacheContainer.getCache().computeIfAbsent(key, k -> { - MonitorItem monitorItem = new MonitorItem(() -> initializer.createMonitor( - connectionService, - connectionService.getPluginService())); + MonitorItem monitorItem = new MonitorItem(() -> initializer.createMonitor(servicesContainer)); monitorItem.getMonitor().start(); return monitorItem; }).getMonitor(); @@ -224,23 +211,25 @@ public T runIfAbsent( Messages.get("MonitorServiceImpl.unexpectedMonitorClass", new Object[] {monitorClass, monitor})); } - protected ConnectionService getConnectionService(StorageService storageService, - TelemetryFactory telemetryFactory, String originalUrl, String driverProtocol, TargetDriverDialect driverDialect, - Dialect dbDialect, Properties originalProps) throws SQLException { - TargetDriverHelper helper = new TargetDriverHelper(); - java.sql.Driver driver = helper.getTargetDriver(originalUrl, originalProps); - final ConnectionProvider defaultConnectionProvider = new DriverConnectionProvider(driver); + protected FullServicesContainer getNewServicesContainer( + StorageService storageService, + TelemetryFactory telemetryFactory, + String originalUrl, + String driverProtocol, + TargetDriverDialect driverDialect, + Dialect dbDialect, + Properties originalProps) throws SQLException { final Properties propsCopy = PropertyUtils.copyProperties(originalProps); - return new ConnectionServiceImpl( + return ServiceContainerUtility.createServiceContainer( storageService, this, telemetryFactory, - defaultConnectionProvider, originalUrl, driverProtocol, driverDialect, dbDialect, - propsCopy); + propsCopy + ); } @Override diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java index 966ff6ec4..a6351a849 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java @@ -1,401 +1,401 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; - -class ClusterAwareReaderFailoverHandlerTest { - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock ConnectionPluginManager mockPluginManager; - @Mock Connection mockConnection; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final List defaultHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader2").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader3").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader4").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader5").port(1234).role(HostRole.READER).build() - ); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testFailover() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: successful connection for host at index 4 - final List hosts = defaultHosts; - final int currentHostIndex = 2; - final int successHostIndex = 4; - for (int i = 0; i < hosts.size(); i++) { - if (i != successHostIndex) { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockConnectionService.open(hosts.get(i), properties)) - .thenThrow(exception); - when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); - } else { - when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); - } - } - - when(mockPluginService.getTargetDriverDialect()).thenReturn(null); - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(successHostIndex), result.getHost()); - - final HostSpec successHost = hosts.get(successHostIndex); - final Map availabilityMap = target.getHostAvailabilityMap(); - Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); - assertTrue(unavailableHosts.size() >= 4); - assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); - } - - private Set getHostsWithGivenAvailability( - Map availabilityMap, HostAvailability availability) { - return availabilityMap.entrySet().stream() - .filter((entry) -> availability.equals(entry.getValue())) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); - } - - @Test - public void testFailover_timeout() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: failure to get reader since process is limited to 5s and each attempt - // to connect takes 20s - final List hosts = defaultHosts; - final int currentHostIndex = 2; - for (HostSpec host : hosts) { - when(mockConnectionService.open(host, properties)) - .thenAnswer((Answer) invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - } - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); - - final long startTimeNano = System.nanoTime(); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() { - ClusterAwareReaderFailoverHandler handler = - spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); - doReturn(mockPluginService).when(handler).getNewPluginService(); - return handler; - } - - private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( - int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) { - ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( - mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); - ClusterAwareReaderFailoverHandler spyHandler = spy(handler); - doReturn(mockPluginService).when(spyHandler).getNewPluginService(); - return spyHandler; - } - - @Test - public void testFailover_nullOrEmptyHostList() throws SQLException { - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final HostSpec currentHost = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); - - ReaderFailoverResult result = target.failover(null, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - final List hosts = new ArrayList<>(); - result = target.failover(hosts, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionSuccess() throws SQLException { - // even number of connection attempts - // first connection attempt to return succeeds, second attempt cancelled - // expected test result: successful connection for host at index 2 - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - final HostSpec slowHost = hosts.get(1); - final HostSpec fastHost = hosts.get(2); - when(mockConnectionService.open(slowHost, properties)) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(2), result.getHost()); - - Map availabilityMap = target.getHostAvailabilityMap(); - assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); - assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); - } - - @Test - public void testGetReader_connectionFailure() throws SQLException { - // odd number of connection attempts - // first connection attempt to return fails - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) - when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionAttemptsTimeout() throws SQLException { - // connection attempts time out before they can succeed - // first connection attempt to return times out - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - when(mockConnectionService.open(any(), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - try { - Thread.sleep(5000); - } catch (InterruptedException exception) { - // ignore - } - return mockConnection; - }); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetHostTuplesByPriority() { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final List hostsByPriority = target.getHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting a writer - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.WRITER) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testGetReaderTuplesByPriority() { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testHostFailoverStrictReaderEnabled() { - - final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(); - final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(); - final List hosts = Arrays.asList(writer, reader); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = - getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); - - // The writer is included because the original writer has likely become a reader. - List expectedHostsByPriority = Arrays.asList(reader, writer); - - List hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. - reader.setAvailability(HostAvailability.NOT_AVAILABLE); - expectedHostsByPriority = Arrays.asList(writer, reader); - - hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Writer node will only be picked if it is the only node in topology; - List expectedWriterHost = Collections.singletonList(writer); - - hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); - assertEquals(expectedWriterHost, hostsByPriority); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import java.util.Set; +// import java.util.concurrent.TimeUnit; +// import java.util.stream.Collectors; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionService; +// +// class ClusterAwareReaderFailoverHandlerTest { +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock ConnectionPluginManager mockPluginManager; +// @Mock Connection mockConnection; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final List defaultHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader2").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader3").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader4").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader5").port(1234).role(HostRole.READER).build() +// ); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testFailover() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: successful connection for host at index 4 +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// final int successHostIndex = 4; +// for (int i = 0; i < hosts.size(); i++) { +// if (i != successHostIndex) { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockConnectionService.open(hosts.get(i), properties)) +// .thenThrow(exception); +// when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); +// } else { +// when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); +// } +// } +// +// when(mockPluginService.getTargetDriverDialect()).thenReturn(null); +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(successHostIndex), result.getHost()); +// +// final HostSpec successHost = hosts.get(successHostIndex); +// final Map availabilityMap = target.getHostAvailabilityMap(); +// Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); +// assertTrue(unavailableHosts.size() >= 4); +// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); +// } +// +// private Set getHostsWithGivenAvailability( +// Map availabilityMap, HostAvailability availability) { +// return availabilityMap.entrySet().stream() +// .filter((entry) -> availability.equals(entry.getValue())) +// .map(Map.Entry::getKey) +// .collect(Collectors.toSet()); +// } +// +// @Test +// public void testFailover_timeout() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: failure to get reader since process is limited to 5s and each attempt +// // to connect takes 20s +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// for (HostSpec host : hosts) { +// when(mockConnectionService.open(host, properties)) +// .thenAnswer((Answer) invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// } +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); +// +// final long startTimeNano = System.nanoTime(); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() { +// ClusterAwareReaderFailoverHandler handler = +// spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); +// doReturn(mockPluginService).when(handler).getNewPluginService(); +// return handler; +// } +// +// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( +// int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) { +// ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( +// mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); +// ClusterAwareReaderFailoverHandler spyHandler = spy(handler); +// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); +// return spyHandler; +// } +// +// @Test +// public void testFailover_nullOrEmptyHostList() throws SQLException { +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final HostSpec currentHost = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); +// +// ReaderFailoverResult result = target.failover(null, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// final List hosts = new ArrayList<>(); +// result = target.failover(hosts, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionSuccess() throws SQLException { +// // even number of connection attempts +// // first connection attempt to return succeeds, second attempt cancelled +// // expected test result: successful connection for host at index 2 +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// final HostSpec slowHost = hosts.get(1); +// final HostSpec fastHost = hosts.get(2); +// when(mockConnectionService.open(slowHost, properties)) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(2), result.getHost()); +// +// Map availabilityMap = target.getHostAvailabilityMap(); +// assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); +// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); +// } +// +// @Test +// public void testGetReader_connectionFailure() throws SQLException { +// // odd number of connection attempts +// // first connection attempt to return fails +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) +// when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionAttemptsTimeout() throws SQLException { +// // connection attempts time out before they can succeed +// // first connection attempt to return times out +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// when(mockConnectionService.open(any(), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// try { +// Thread.sleep(5000); +// } catch (InterruptedException exception) { +// // ignore +// } +// return mockConnection; +// }); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetHostTuplesByPriority() { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final List hostsByPriority = target.getHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting a writer +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.WRITER) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testGetReaderTuplesByPriority() { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testHostFailoverStrictReaderEnabled() { +// +// final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(); +// final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(); +// final List hosts = Arrays.asList(writer, reader); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = +// getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); +// +// // The writer is included because the original writer has likely become a reader. +// List expectedHostsByPriority = Arrays.asList(reader, writer); +// +// List hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. +// reader.setAvailability(HostAvailability.NOT_AVAILABLE); +// expectedHostsByPriority = Arrays.asList(writer, reader); +// +// hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Writer node will only be picked if it is the only node in topology; +// List expectedWriterHost = Collections.singletonList(writer); +// +// hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); +// assertEquals(expectedWriterHost, hostsByPriority); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 1ad394010..097b68d67 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -1,373 +1,373 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.refEq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; - -class ClusterAwareWriterFailoverHandlerTest { - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - @Mock ReaderFailoverHandler mockReaderFailoverHandler; - @Mock Connection mockWriterConnection; - @Mock Connection mockNewWriterConnection; - @Mock Connection mockReaderAConnection; - @Mock Connection mockReaderBConnection; - @Mock Dialect mockDialect; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-writer-host").build(); - private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer-host").build(); - private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-a-host").build(); - private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-b-host").build(); - private final List topology = Arrays.asList(writer, readerA, readerB); - private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - writer.addAlias("writer-host"); - newWriterHost.addAlias("new-writer-host"); - readerA.addAlias("reader-a-host"); - readerB.addAlias("reader-b-host"); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testReconnectToWriter_taskBReaderException() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockConnection); - - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( - final int failoverTimeoutMs, - final int readTopologyIntervalMs, - final int reconnectWriterIntervalMs) { - ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( - mockContainer, - mockConnectionService, - mockReaderFailoverHandler, - properties, - failoverTimeoutMs, - readTopologyIntervalMs, - reconnectWriterIntervalMs); - - ClusterAwareWriterFailoverHandler spyHandler = spy(handler); - doReturn(mockPluginService).when(spyHandler).getNewPluginService(); - return spyHandler; - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: successfully re-connect to initial writer; return new connection. - * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_SlowReaderA() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return new ReaderFailoverResult(mockReaderAConnection, readerA, true); - }); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes. - * TaskA: successfully re-connect to writer; return new connection. - * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_taskBDefers() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. - * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more - * time than taskB. - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_SlowWriter() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(3, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. - * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_taskADefers() throws SQLException { - when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockNewWriterConnection; - }); - - final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(4, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to failover timeout. - * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_failoverTimeout() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockNewWriterConnection; - }); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - - final long startTimeNano = System.nanoTime(); - final WriterFailoverResult result = target.failover(topology); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to exception. - * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); - when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.ArgumentMatchers.refEq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Arrays; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.ArgumentMatchers; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionService; +// +// class ClusterAwareWriterFailoverHandlerTest { +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// @Mock ReaderFailoverHandler mockReaderFailoverHandler; +// @Mock Connection mockWriterConnection; +// @Mock Connection mockNewWriterConnection; +// @Mock Connection mockReaderAConnection; +// @Mock Connection mockReaderBConnection; +// @Mock Dialect mockDialect; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("new-writer-host").build(); +// private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer-host").build(); +// private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-a-host").build(); +// private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-b-host").build(); +// private final List topology = Arrays.asList(writer, readerA, readerB); +// private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// writer.addAlias("writer-host"); +// newWriterHost.addAlias("new-writer-host"); +// readerA.addAlias("reader-a-host"); +// readerB.addAlias("reader-b-host"); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testReconnectToWriter_taskBReaderException() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockConnection); +// +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( +// final int failoverTimeoutMs, +// final int readTopologyIntervalMs, +// final int reconnectWriterIntervalMs) { +// ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( +// mockContainer, +// mockConnectionService, +// mockReaderFailoverHandler, +// properties, +// failoverTimeoutMs, +// readTopologyIntervalMs, +// reconnectWriterIntervalMs); +// +// ClusterAwareWriterFailoverHandler spyHandler = spy(handler); +// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); +// return spyHandler; +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: successfully re-connect to initial writer; return new connection. +// * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_SlowReaderA() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return new ReaderFailoverResult(mockReaderAConnection, readerA, true); +// }); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes. +// * TaskA: successfully re-connect to writer; return new connection. +// * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_taskBDefers() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. +// * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more +// * time than taskB. +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_SlowWriter() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(3, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. +// * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_taskADefers() throws SQLException { +// when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockNewWriterConnection; +// }); +// +// final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(4, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to failover timeout. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_failoverTimeout() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockNewWriterConnection; +// }); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// +// final long startTimeNano = System.nanoTime(); +// final WriterFailoverResult result = target.failover(topology); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to exception. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); +// when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 61248faaa..147a81135 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -1,306 +1,306 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.util.monitoring; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; - -import java.sql.SQLException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionService; -import software.amazon.jdbc.util.events.EventPublisher; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; - -class MonitorServiceImplTest { - @Mock StorageService mockStorageService; - @Mock ConnectionService mockConnectionService; - @Mock TelemetryFactory mockTelemetryFactory; - @Mock TargetDriverDialect mockTargetDriverDialect; - @Mock Dialect mockDbDialect; - @Mock EventPublisher mockPublisher; - MonitorServiceImpl spyMonitorService; - private AutoCloseable closeable; - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); - doNothing().when(spyMonitorService).initCleanupThread(anyInt()); - - try { - doReturn(mockConnectionService).when(spyMonitorService) - .getConnectionService(any(), any(), any(), any(), any(), any(), any()); - } catch (SQLException e) { - Assertions.fail( - "Encountered exception while stubbing MonitorServiceImpl#getConnectionService: " + e.getMessage()); - } - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - spyMonitorService.releaseResources(); - } - - @Test - public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - monitor.state.set(MonitorState.ERROR); - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(newMonitor); - assertNotEquals(monitor, newMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, newMonitor.getState()); - } - - @Test - public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - 1, // heartbeat times out immediately - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(newMonitor); - assertNotEquals(monitor, newMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, newMonitor.getState()); - } - - @Test - public void testMonitorExpired() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - // checkMonitors() should detect the expiration timeout and stop/remove the monitor. - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - // monitor should have been removed when checkMonitors() was called. - assertNull(newMonitor); - } - - @Test - public void testMonitorMismatch() { - assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( - CustomEndpointMonitorImpl.class, - "testMonitor", - mockStorageService, - mockTelemetryFactory, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor - // service should detect this and throw an exception. - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - )); - } - - @Test - public void testRemove() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - assertNotNull(monitor); - - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); - assertEquals(monitor, removedMonitor); - assertEquals(MonitorState.RUNNING, monitor.getState()); - } - - @Test - public void testStopAndRemove() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - assertNotNull(monitor); - - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - spyMonitorService.stopAndRemove(NoOpMonitor.class, key); - assertNull(spyMonitorService.get(NoOpMonitor.class, key)); - assertEquals(MonitorState.STOPPED, monitor.getState()); - } - - static class NoOpMonitor extends AbstractMonitor { - protected NoOpMonitor( - MonitorService monitorService, - long terminationTimeoutSec) { - super(terminationTimeoutSec); - } - - @Override - public void monitor() { - // do nothing. - } - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.util.monitoring; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.Mockito.doNothing; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// +// import java.sql.SQLException; +// import java.util.Collections; +// import java.util.HashSet; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.Assertions; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.connection.ConnectionService; +// import software.amazon.jdbc.util.events.EventPublisher; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// +// class MonitorServiceImplTest { +// @Mock StorageService mockStorageService; +// @Mock ConnectionService mockConnectionService; +// @Mock TelemetryFactory mockTelemetryFactory; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// @Mock Dialect mockDbDialect; +// @Mock EventPublisher mockPublisher; +// MonitorServiceImpl spyMonitorService; +// private AutoCloseable closeable; +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); +// doNothing().when(spyMonitorService).initCleanupThread(anyInt()); +// +// try { +// doReturn(mockConnectionService).when(spyMonitorService) +// .getConnectionService(any(), any(), any(), any(), any(), any(), any()); +// } catch (SQLException e) { +// Assertions.fail( +// "Encountered exception while stubbing MonitorServiceImpl#getConnectionService: " + e.getMessage()); +// } +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// spyMonitorService.releaseResources(); +// } +// +// @Test +// public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// monitor.state.set(MonitorState.ERROR); +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(newMonitor); +// assertNotEquals(monitor, newMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, newMonitor.getState()); +// } +// +// @Test +// public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// 1, // heartbeat times out immediately +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(newMonitor); +// assertNotEquals(monitor, newMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, newMonitor.getState()); +// } +// +// @Test +// public void testMonitorExpired() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// // checkMonitors() should detect the expiration timeout and stop/remove the monitor. +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// // monitor should have been removed when checkMonitors() was called. +// assertNull(newMonitor); +// } +// +// @Test +// public void testMonitorMismatch() { +// assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( +// CustomEndpointMonitorImpl.class, +// "testMonitor", +// mockStorageService, +// mockTelemetryFactory, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor +// // service should detect this and throw an exception. +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// )); +// } +// +// @Test +// public void testRemove() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// assertNotNull(monitor); +// +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); +// assertEquals(monitor, removedMonitor); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// } +// +// @Test +// public void testStopAndRemove() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// assertNotNull(monitor); +// +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// spyMonitorService.stopAndRemove(NoOpMonitor.class, key); +// assertNull(spyMonitorService.get(NoOpMonitor.class, key)); +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// } +// +// static class NoOpMonitor extends AbstractMonitor { +// protected NoOpMonitor( +// MonitorService monitorService, +// long terminationTimeoutSec) { +// super(terminationTimeoutSec); +// } +// +// @Override +// public void monitor() { +// // do nothing. +// } +// } +// } From f9871643a7ee2020627978344a8ac05b25cc0005 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 2 Sep 2025 11:48:59 -0700 Subject: [PATCH 27/54] Add ConnectionProvider arg to FullServicesContainer constructor --- .../amazon/jdbc/ConnectionPluginChainBuilder.java | 3 +-- .../amazon/jdbc/ConnectionPluginManager.java | 3 +-- .../src/main/java/software/amazon/jdbc/Driver.java | 3 ++- .../amazon/jdbc/ds/AwsWrapperDataSource.java | 3 ++- .../amazon/jdbc/util/FullServicesContainer.java | 3 +++ .../amazon/jdbc/util/FullServicesContainerImpl.java | 12 +++++++++++- .../amazon/jdbc/util/ServiceContainerUtility.java | 13 ++++++------- .../jdbc/util/connection/ConnectionServiceImpl.java | 3 ++- 8 files changed, 28 insertions(+), 15 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 411e40cd8..6c8475a26 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -148,8 +148,7 @@ public List getPlugins( final ConnectionProvider effectiveConnProvider, final PluginManagerService pluginManagerService, final Properties props, - @Nullable ConfigurationProfile configurationProfile) - throws SQLException { + @Nullable ConfigurationProfile configurationProfile) { List plugins; List pluginFactories; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 8757e2c0a..6272072af 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -180,8 +180,7 @@ public void init( final FullServicesContainer servicesContainer, final Properties props, final PluginManagerService pluginManagerService, - @Nullable ConfigurationProfile configurationProfile) - throws SQLException { + @Nullable ConfigurationProfile configurationProfile) { this.props = props; this.servicesContainer = servicesContainer; diff --git a/wrapper/src/main/java/software/amazon/jdbc/Driver.java b/wrapper/src/main/java/software/amazon/jdbc/Driver.java index e4f5a14af..9daa92814 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/Driver.java +++ b/wrapper/src/main/java/software/amazon/jdbc/Driver.java @@ -240,7 +240,8 @@ public Connection connect(final String url, final Properties info) throws SQLExc } FullServicesContainer - servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); + servicesContainer = new FullServicesContainerImpl( + storageService, monitorService, defaultConnectionProvider, telemetryFactory); return new ConnectionWrapper( servicesContainer, diff --git a/wrapper/src/main/java/software/amazon/jdbc/ds/AwsWrapperDataSource.java b/wrapper/src/main/java/software/amazon/jdbc/ds/AwsWrapperDataSource.java index 0c54e018e..41f5fc23f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ds/AwsWrapperDataSource.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ds/AwsWrapperDataSource.java @@ -282,7 +282,8 @@ ConnectionWrapper createConnectionWrapper( final @Nullable ConfigurationProfile configurationProfile, final TelemetryFactory telemetryFactory) throws SQLException { FullServicesContainer - servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); + servicesContainer = new FullServicesContainerImpl( + storageService, monitorService, defaultProvider, telemetryFactory); return new ConnectionWrapper( servicesContainer, props, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainer.java b/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainer.java index e276b4503..7b7857175 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainer.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainer.java @@ -17,6 +17,7 @@ package software.amazon.jdbc.util; import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; @@ -36,6 +37,8 @@ public interface FullServicesContainer { MonitorService getMonitorService(); + ConnectionProvider getDefaultConnectionProvider(); + TelemetryFactory getTelemetryFactory(); ConnectionPluginManager getConnectionPluginManager(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainerImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainerImpl.java index ef0a0fc53..db0ea3f57 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainerImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/FullServicesContainerImpl.java @@ -17,6 +17,7 @@ package software.amazon.jdbc.util; import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; @@ -27,6 +28,7 @@ public class FullServicesContainerImpl implements FullServicesContainer { private StorageService storageService; private MonitorService monitorService; + private ConnectionProvider defaultConnProvider; private TelemetryFactory telemetryFactory; private ConnectionPluginManager connectionPluginManager; private HostListProviderService hostListProviderService; @@ -36,12 +38,13 @@ public class FullServicesContainerImpl implements FullServicesContainer { public FullServicesContainerImpl( StorageService storageService, MonitorService monitorService, + ConnectionProvider defaultConnProvider, TelemetryFactory telemetryFactory, ConnectionPluginManager connectionPluginManager, HostListProviderService hostListProviderService, PluginService pluginService, PluginManagerService pluginManagerService) { - this(storageService, monitorService, telemetryFactory); + this(storageService, monitorService, defaultConnProvider, telemetryFactory); this.connectionPluginManager = connectionPluginManager; this.hostListProviderService = hostListProviderService; this.pluginService = pluginService; @@ -51,9 +54,11 @@ public FullServicesContainerImpl( public FullServicesContainerImpl( StorageService storageService, MonitorService monitorService, + ConnectionProvider defaultConnProvider, TelemetryFactory telemetryFactory) { this.storageService = storageService; this.monitorService = monitorService; + this.defaultConnProvider = defaultConnProvider; this.telemetryFactory = telemetryFactory; } @@ -67,6 +72,11 @@ public MonitorService getMonitorService() { return this.monitorService; } + @Override + public ConnectionProvider getDefaultConnectionProvider() { + return this.defaultConnProvider; + } + @Override public TelemetryFactory getTelemetryFactory() { return this.telemetryFactory; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java index e0e57b793..edd922506 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java @@ -60,20 +60,18 @@ public static ServiceContainerUtility getInstance() { public static FullServicesContainer createServiceContainer( StorageService storageService, MonitorService monitorService, + ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, String originalUrl, String targetDriverProtocol, TargetDriverDialect driverDialect, Dialect dbDialect, - Properties props) throws SQLException { - final TargetDriverHelper helper = new TargetDriverHelper(); - final java.sql.Driver driver = helper.getTargetDriver(originalUrl, props); - final ConnectionProvider connProvider = new DriverConnectionProvider(driver); - + Properties props) { FullServicesContainer - servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); + servicesContainer = new FullServicesContainerImpl( + storageService, monitorService, connectionProvider, telemetryFactory); ConnectionPluginManager pluginManager = new ConnectionPluginManager( - connProvider, + connectionProvider, null, null, telemetryFactory); @@ -92,6 +90,7 @@ public static FullServicesContainer createServiceContainer( return new FullServicesContainerImpl( storageService, monitorService, + connectionProvider, telemetryFactory, pluginManager, partialPluginService, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 0e506a112..52744d494 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -50,7 +50,8 @@ public ConnectionServiceImpl( this.targetDriverProtocol = targetDriverProtocol; FullServicesContainer - servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); + servicesContainer = new FullServicesContainerImpl( + storageService, monitorService, connectionProvider, telemetryFactory); this.pluginManager = new ConnectionPluginManager( connectionProvider, null, From 47d2e2079303a8e4594a92f4b5bc9346054a1627 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 2 Sep 2025 15:15:30 -0700 Subject: [PATCH 28/54] Utils.getWriter, PluginService.getDefaultConnectionProvider, isolate new plugin services --- .../amazon/jdbc/PartialPluginService.java | 29 +- .../software/amazon/jdbc/PluginService.java | 2 + .../amazon/jdbc/PluginServiceImpl.java | 16 +- .../plugin/AuroraConnectionTrackerPlugin.java | 17 +- ...AuroraInitialConnectionStrategyPlugin.java | 12 +- .../ClusterAwareReaderFailoverHandler.java | 12 +- .../ClusterAwareWriterFailoverHandler.java | 40 +- .../failover/FailoverConnectionPlugin.java | 32 +- .../ReadWriteSplittingPlugin.java | 9 +- .../plugin/staledns/AuroraStaleDnsHelper.java | 11 +- .../java/software/amazon/jdbc/util/Utils.java | 14 + .../connection/ConnectionServiceImpl.java | 22 +- ...ClusterAwareReaderFailoverHandlerTest.java | 802 +++++++++--------- ...ClusterAwareWriterFailoverHandlerTest.java | 746 ++++++++-------- 14 files changed, 859 insertions(+), 905 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 880c6f0dc..9b36ea6d0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -83,13 +83,15 @@ public class PartialPluginService implements PluginService, CanReleaseResources, public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, + @NonNull final ConnectionProvider defaultConnectionProvider, @NonNull final Properties props, @NonNull final String originalUrl, @NonNull final String targetDriverProtocol, @NonNull final TargetDriverDialect targetDriverDialect, - @NonNull final Dialect dbDialect) { + @NonNull final Dialect dbDialect) throws SQLException { this( servicesContainer, + defaultConnectionProvider, new ExceptionManager(), props, originalUrl, @@ -101,15 +103,15 @@ public PartialPluginService( public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, + @NonNull final ConnectionProvider defaultConnectionProvider, @NonNull final ExceptionManager exceptionManager, @NonNull final Properties props, @NonNull final String originalUrl, @NonNull final String targetDriverProtocol, @NonNull final TargetDriverDialect targetDriverDialect, @NonNull final Dialect dbDialect, - @Nullable final ConfigurationProfile configurationProfile) { + @Nullable final ConfigurationProfile configurationProfile) throws SQLException { this.servicesContainer = servicesContainer; - this.pluginManager = servicesContainer.getConnectionPluginManager(); this.props = props; this.originalUrl = originalUrl; this.driverProtocol = targetDriverProtocol; @@ -117,6 +119,11 @@ public PartialPluginService( this.dbDialect = dbDialect; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; + + this.pluginManager = new ConnectionPluginManager( + defaultConnectionProvider, null, null, servicesContainer.getTelemetryFactory()); + this.pluginManager.init(this.servicesContainer, this.props, this, this.configurationProfile); + this.servicesContainer.setConnectionPluginManager(pluginManager); this.connectionProviderManager = new ConnectionProviderManager( this.pluginManager.getDefaultConnProvider(), this.pluginManager.getEffectiveConnProvider()); @@ -149,7 +156,7 @@ public HostSpec getCurrentHostSpec() { throw new RuntimeException(Messages.get("PluginServiceImpl.hostListEmpty")); } - this.currentHostSpec = this.getWriter(this.getAllHosts()); + this.currentHostSpec = Utils.getWriter(this.getAllHosts()); final List allowedHosts = this.getHosts(); if (!Utils.containsUrl(allowedHosts, this.currentHostSpec.getUrl())) { throw new RuntimeException( @@ -215,21 +222,17 @@ public HostRole getHostRole(Connection conn) throws SQLException { return this.hostListProvider.getHostRole(conn); } - private HostSpec getWriter(final @NonNull List hosts) { - for (final HostSpec hostSpec : hosts) { - if (hostSpec.getRole() == HostRole.WRITER) { - return hostSpec; - } - } - return null; - } - @Override @Deprecated public ConnectionProvider getConnectionProvider() { return this.pluginManager.defaultConnProvider; } + @Override + public ConnectionProvider getDefaultConnectionProvider() { + return this.connectionProviderManager.getDefaultProvider(); + } + public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = this.connectionProviderManager.getConnectionProvider(this.driverProtocol, host, props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index fe4dc8cc6..6c9180650 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -242,6 +242,8 @@ Connection forceConnect( @Deprecated ConnectionProvider getConnectionProvider(); + ConnectionProvider getDefaultConnectionProvider(); + boolean isPooledConnectionProvider(HostSpec host, Properties props); String getDriverProtocol(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 38695c144..43a8d0e28 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -180,7 +180,7 @@ public HostSpec getCurrentHostSpec() { throw new RuntimeException(Messages.get("PluginServiceImpl.hostListEmpty")); } - this.currentHostSpec = this.getWriter(this.getAllHosts()); + this.currentHostSpec = Utils.getWriter(this.getAllHosts()); final List allowedHosts = this.getHosts(); if (!Utils.containsUrl(allowedHosts, this.currentHostSpec.getUrl())) { throw new RuntimeException( @@ -243,21 +243,17 @@ public HostRole getHostRole(Connection conn) throws SQLException { return this.hostListProvider.getHostRole(conn); } - private HostSpec getWriter(final @NonNull List hosts) { - for (final HostSpec hostSpec : hosts) { - if (hostSpec.getRole() == HostRole.WRITER) { - return hostSpec; - } - } - return null; - } - @Override @Deprecated public ConnectionProvider getConnectionProvider() { return this.pluginManager.defaultConnProvider; } + @Override + public ConnectionProvider getDefaultConnectionProvider() { + return this.connectionProviderManager.getDefaultProvider(); + } + public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = this.connectionProviderManager.getConnectionProvider(this.driverProtocol, host, props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 6d851b92f..dae5c3558 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -21,15 +21,12 @@ import java.util.Collections; import java.util.EnumSet; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Logger; -import org.checkerframework.checker.nullness.qual.NonNull; -import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.JdbcMethod; @@ -38,6 +35,7 @@ import software.amazon.jdbc.plugin.failover.FailoverSQLException; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.Utils; public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { @@ -156,7 +154,7 @@ private void checkWriterChanged(boolean needRefreshHostLists) { // do nothing } } - final HostSpec hostSpecAfterFailover = this.getWriter(this.pluginService.getAllHosts()); + final HostSpec hostSpecAfterFailover = Utils.getWriter(this.pluginService.getAllHosts()); if (this.currentWriter == null) { this.currentWriter = hostSpecAfterFailover; @@ -174,7 +172,7 @@ private void checkWriterChanged(boolean needRefreshHostLists) { private void rememberWriter() { if (this.currentWriter == null || this.needUpdateCurrentWriter) { - this.currentWriter = this.getWriter(this.pluginService.getAllHosts()); + this.currentWriter = Utils.getWriter(this.pluginService.getAllHosts()); this.needUpdateCurrentWriter = false; } } @@ -191,13 +189,4 @@ public void notifyNodeListChanged(final Map> } } } - - private HostSpec getWriter(final @NonNull List hosts) { - for (final HostSpec hostSpec : hosts) { - if (hostSpec.getRole() == HostRole.WRITER) { - return hostSpec; - } - } - return null; - } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 7b1eedbad..01b9bf8e9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -37,6 +37,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlugin { @@ -187,7 +188,7 @@ private Connection getVerifiedWriterConnection( writerCandidate = null; try { - writerCandidate = this.getWriter(); + writerCandidate = Utils.getWriter(this.pluginService.getAllHosts()); if (writerCandidate == null || this.rdsUtils.isRdsClusterDns(writerCandidate.getHost())) { @@ -362,15 +363,6 @@ private void delay(final long delayMs) { } } - private HostSpec getWriter() { - for (final HostSpec host : this.pluginService.getAllHosts()) { - if (host.getRole() == HostRole.WRITER) { - return host; - } - } - return null; - } - private HostSpec getReader(final Properties props) throws SQLException { final String strategy = READER_HOST_SELECTOR_STRATEGY.getString(props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 35f16b35f..4b2373c19 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -41,6 +41,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.FullServicesContainerImpl; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.Utils; @@ -396,9 +397,16 @@ private ReaderFailoverResult getNextResult(final CompletionService currentTopology) } } - private static HostSpec getWriter(final List topology) { - if (topology == null || topology.isEmpty()) { - return null; - } - - for (final HostSpec host : topology) { - if (host.getRole() == HostRole.WRITER) { - return host; - } - } - return null; - } - private void submitTasks( final List currentTopology, final ExecutorService executorService, final CompletionService completionService, - final boolean singleTask) { - final HostSpec writerHost = getWriter(currentTopology); + final boolean singleTask) throws SQLException { + final HostSpec writerHost = Utils.getWriter(currentTopology); if (!singleTask) { completionService.submit( new ReconnectToWriterHandler( @@ -189,11 +176,18 @@ private void submitTasks( executorService.shutdown(); } - protected PluginService getNewPluginService() { - // Each task should get its own PluginService since they execute concurrently and PluginService was not designed to - // be thread-safe. + // Each task should get its own PluginService since they execute concurrently and PluginService was not designed to + // be thread-safe. + protected PluginService getNewPluginService() throws SQLException { + FullServicesContainer newServicesContainer = new FullServicesContainerImpl( + this.servicesContainer.getStorageService(), + this.servicesContainer.getMonitorService(), + this.servicesContainer.getTelemetryFactory() + ); + return new PartialPluginService( - this.servicesContainer, + newServicesContainer, + this.pluginService.getDefaultConnectionProvider(), this.initialConnectionProps, this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), @@ -248,7 +242,7 @@ private void logTaskSuccess(final WriterFailoverResult result) { return; } - final HostSpec writerHost = getWriter(topology); + final HostSpec writerHost = Utils.getWriter(topology); final String newWriterHost = writerHost == null ? null : writerHost.getUrl(); if (result.isNewHost()) { LOGGER.fine( @@ -357,7 +351,7 @@ public WriterFailoverResult call() { } private boolean isCurrentHostWriter(final List latestTopology) { - final HostSpec latestWriter = getWriter(latestTopology); + final HostSpec latestWriter = Utils.getWriter(latestTopology); final Set latestWriterAllAliases = latestWriter.asAliases(); final Set currentAliases = this.originalWriterHost.asAliases(); @@ -494,7 +488,7 @@ private boolean refreshTopologyAndConnectToNewWriter() throws InterruptedExcepti } else { this.currentTopology = topology; - final HostSpec writerCandidate = getWriter(this.currentTopology); + final HostSpec writerCandidate = Utils.getWriter(this.currentTopology); if (allowOldWriter || !isSame(writerCandidate, this.originalWriterHost)) { // new writer is available, and it's different from the previous writer diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 19eb59a48..ab3804be8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -31,7 +31,7 @@ import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; -import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.DriverConnectionProvider; @@ -450,23 +450,6 @@ private void invalidInvocationOnClosedConnection() throws SQLException { } } - private HostSpec getCurrentWriter() throws SQLException { - final List topology = this.pluginService.getAllHosts(); - if (topology == null) { - return null; - } - return getWriter(topology); - } - - private HostSpec getWriter(final @NonNull List hosts) { - for (final HostSpec hostSpec : hosts) { - if (hostSpec.getRole() == HostRole.WRITER) { - return hostSpec; - } - } - return null; - } - protected void updateTopology(final boolean forceUpdate) throws SQLException { final Connection connection = this.pluginService.getCurrentConnection(); if (!isFailoverEnabled() || connection == null || connection.isClosed()) { @@ -618,8 +601,11 @@ protected void dealWithIllegalStateException( * @param failedHost The host with network errors. * @throws SQLException if an error occurs */ - protected void failover(final HostSpec failedHost) throws SQLException { - this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); + protected void failover(@Nullable final HostSpec failedHost) throws SQLException { + if (failedHost != null) { + this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); + } + if (this.connectionService == null) { this.connectionService = getConnectionService(); } @@ -794,7 +780,7 @@ protected void failoverWriter() throws SQLException { } List hosts = failoverResult.getTopology(); - final HostSpec writerHostSpec = getWriter(hosts); + final HostSpec writerHostSpec = Utils.getWriter(hosts); if (writerHostSpec == null) { throwFailoverFailedException( Messages.get( @@ -898,9 +884,9 @@ protected void pickNewConnection() throws SQLException { if (this.pluginService.getCurrentConnection() == null && !shouldAttemptReaderConnection()) { try { - connectTo(getCurrentWriter()); + connectTo(Utils.getWriter(this.pluginService.getAllHosts())); } catch (final SQLException e) { - failover(getCurrentWriter()); + failover(Utils.getWriter(this.pluginService.getAllHosts())); } } else { failover(this.pluginService.getCurrentHostSpec()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 90b170f98..22dba8a0a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -475,14 +475,7 @@ private void initializeReaderConnection(final @NonNull List hosts) thr } private HostSpec getWriter(final @NonNull List hosts) throws SQLException { - HostSpec writerHost = null; - for (final HostSpec hostSpec : hosts) { - if (HostRole.WRITER.equals(hostSpec.getRole())) { - writerHost = hostSpec; - break; - } - } - + HostSpec writerHost = Utils.getWriter(hosts); if (writerHost == null) { logAndThrowException(Messages.get("ReadWriteSplittingPlugin.noWriterFound")); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java index ffcaaef82..93b62536d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java @@ -99,7 +99,7 @@ public Connection getVerifiedConnection( LOGGER.finest(() -> Utils.logTopology(this.pluginService.getAllHosts())); if (this.writerHostSpec == null) { - final HostSpec writerCandidate = this.getWriter(); + final HostSpec writerCandidate = Utils.getWriter(this.pluginService.getAllHosts()); if (writerCandidate != null && this.rdsUtils.isRdsClusterDns(writerCandidate.getHost())) { return null; } @@ -181,13 +181,4 @@ public void notifyNodeListChanged(final Map> } } } - - private HostSpec getWriter() { - for (final HostSpec host : this.pluginService.getAllHosts()) { - if (host.getRole() == HostRole.WRITER) { - return host; - } - } - return null; - } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java b/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java index 4927ac323..d91538d05 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java @@ -19,6 +19,7 @@ import java.util.Collection; import java.util.List; import org.checkerframework.checker.nullness.qual.Nullable; +import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; public class Utils { @@ -36,6 +37,19 @@ public static boolean containsUrl(final List hosts, String url) { return false; } + public static @Nullable HostSpec getWriter(final List topology) { + if (topology == null || topology.isEmpty()) { + return null; + } + + for (final HostSpec host : topology) { + if (host.getRole() == HostRole.WRITER) { + return host; + } + } + return null; + } + public static String logTopology(final @Nullable List hosts) { return logTopology(hosts, null); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 0e506a112..8b2c23e97 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -19,7 +19,6 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.Properties; -import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; @@ -33,8 +32,6 @@ import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class ConnectionServiceImpl implements ConnectionService { - protected final String targetDriverProtocol; - protected final ConnectionPluginManager pluginManager; protected final PluginService pluginService; public ConnectionServiceImpl( @@ -47,33 +44,22 @@ public ConnectionServiceImpl( TargetDriverDialect driverDialect, Dialect dbDialect, Properties props) throws SQLException { - this.targetDriverProtocol = targetDriverProtocol; - FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); - this.pluginManager = new ConnectionPluginManager( - connectionProvider, - null, - null, - telemetryFactory); - servicesContainer.setConnectionPluginManager(this.pluginManager); - - PartialPluginService partialPluginService = new PartialPluginService( + this.pluginService = new PartialPluginService( servicesContainer, + connectionProvider, props, originalUrl, - this.targetDriverProtocol, + targetDriverProtocol, driverDialect, dbDialect ); - - this.pluginService = partialPluginService; - this.pluginManager.init(servicesContainer, props, partialPluginService, null); } @Override public Connection open(HostSpec hostSpec, Properties props) throws SQLException { - return this.pluginManager.forceConnect(this.targetDriverProtocol, hostSpec, props, true, null); + return this.pluginService.forceConnect(hostSpec, props); } @Override diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java index 966ff6ec4..a6351a849 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java @@ -1,401 +1,401 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; - -class ClusterAwareReaderFailoverHandlerTest { - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock ConnectionPluginManager mockPluginManager; - @Mock Connection mockConnection; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final List defaultHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader2").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader3").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader4").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader5").port(1234).role(HostRole.READER).build() - ); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testFailover() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: successful connection for host at index 4 - final List hosts = defaultHosts; - final int currentHostIndex = 2; - final int successHostIndex = 4; - for (int i = 0; i < hosts.size(); i++) { - if (i != successHostIndex) { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockConnectionService.open(hosts.get(i), properties)) - .thenThrow(exception); - when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); - } else { - when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); - } - } - - when(mockPluginService.getTargetDriverDialect()).thenReturn(null); - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(successHostIndex), result.getHost()); - - final HostSpec successHost = hosts.get(successHostIndex); - final Map availabilityMap = target.getHostAvailabilityMap(); - Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); - assertTrue(unavailableHosts.size() >= 4); - assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); - } - - private Set getHostsWithGivenAvailability( - Map availabilityMap, HostAvailability availability) { - return availabilityMap.entrySet().stream() - .filter((entry) -> availability.equals(entry.getValue())) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); - } - - @Test - public void testFailover_timeout() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: failure to get reader since process is limited to 5s and each attempt - // to connect takes 20s - final List hosts = defaultHosts; - final int currentHostIndex = 2; - for (HostSpec host : hosts) { - when(mockConnectionService.open(host, properties)) - .thenAnswer((Answer) invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - } - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); - - final long startTimeNano = System.nanoTime(); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() { - ClusterAwareReaderFailoverHandler handler = - spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); - doReturn(mockPluginService).when(handler).getNewPluginService(); - return handler; - } - - private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( - int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) { - ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( - mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); - ClusterAwareReaderFailoverHandler spyHandler = spy(handler); - doReturn(mockPluginService).when(spyHandler).getNewPluginService(); - return spyHandler; - } - - @Test - public void testFailover_nullOrEmptyHostList() throws SQLException { - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final HostSpec currentHost = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); - - ReaderFailoverResult result = target.failover(null, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - final List hosts = new ArrayList<>(); - result = target.failover(hosts, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionSuccess() throws SQLException { - // even number of connection attempts - // first connection attempt to return succeeds, second attempt cancelled - // expected test result: successful connection for host at index 2 - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - final HostSpec slowHost = hosts.get(1); - final HostSpec fastHost = hosts.get(2); - when(mockConnectionService.open(slowHost, properties)) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(2), result.getHost()); - - Map availabilityMap = target.getHostAvailabilityMap(); - assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); - assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); - } - - @Test - public void testGetReader_connectionFailure() throws SQLException { - // odd number of connection attempts - // first connection attempt to return fails - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) - when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionAttemptsTimeout() throws SQLException { - // connection attempts time out before they can succeed - // first connection attempt to return times out - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - when(mockConnectionService.open(any(), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - try { - Thread.sleep(5000); - } catch (InterruptedException exception) { - // ignore - } - return mockConnection; - }); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetHostTuplesByPriority() { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final List hostsByPriority = target.getHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting a writer - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.WRITER) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testGetReaderTuplesByPriority() { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testHostFailoverStrictReaderEnabled() { - - final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(); - final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(); - final List hosts = Arrays.asList(writer, reader); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = - getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); - - // The writer is included because the original writer has likely become a reader. - List expectedHostsByPriority = Arrays.asList(reader, writer); - - List hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. - reader.setAvailability(HostAvailability.NOT_AVAILABLE); - expectedHostsByPriority = Arrays.asList(writer, reader); - - hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Writer node will only be picked if it is the only node in topology; - List expectedWriterHost = Collections.singletonList(writer); - - hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); - assertEquals(expectedWriterHost, hostsByPriority); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import java.util.Set; +// import java.util.concurrent.TimeUnit; +// import java.util.stream.Collectors; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionService; +// +// class ClusterAwareReaderFailoverHandlerTest { +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock ConnectionPluginManager mockPluginManager; +// @Mock Connection mockConnection; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final List defaultHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader2").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader3").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader4").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader5").port(1234).role(HostRole.READER).build() +// ); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testFailover() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: successful connection for host at index 4 +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// final int successHostIndex = 4; +// for (int i = 0; i < hosts.size(); i++) { +// if (i != successHostIndex) { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockConnectionService.open(hosts.get(i), properties)) +// .thenThrow(exception); +// when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); +// } else { +// when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); +// } +// } +// +// when(mockPluginService.getTargetDriverDialect()).thenReturn(null); +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(successHostIndex), result.getHost()); +// +// final HostSpec successHost = hosts.get(successHostIndex); +// final Map availabilityMap = target.getHostAvailabilityMap(); +// Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); +// assertTrue(unavailableHosts.size() >= 4); +// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); +// } +// +// private Set getHostsWithGivenAvailability( +// Map availabilityMap, HostAvailability availability) { +// return availabilityMap.entrySet().stream() +// .filter((entry) -> availability.equals(entry.getValue())) +// .map(Map.Entry::getKey) +// .collect(Collectors.toSet()); +// } +// +// @Test +// public void testFailover_timeout() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: failure to get reader since process is limited to 5s and each attempt +// // to connect takes 20s +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// for (HostSpec host : hosts) { +// when(mockConnectionService.open(host, properties)) +// .thenAnswer((Answer) invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// } +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); +// +// final long startTimeNano = System.nanoTime(); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() { +// ClusterAwareReaderFailoverHandler handler = +// spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); +// doReturn(mockPluginService).when(handler).getNewPluginService(); +// return handler; +// } +// +// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( +// int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) { +// ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( +// mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); +// ClusterAwareReaderFailoverHandler spyHandler = spy(handler); +// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); +// return spyHandler; +// } +// +// @Test +// public void testFailover_nullOrEmptyHostList() throws SQLException { +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final HostSpec currentHost = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); +// +// ReaderFailoverResult result = target.failover(null, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// final List hosts = new ArrayList<>(); +// result = target.failover(hosts, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionSuccess() throws SQLException { +// // even number of connection attempts +// // first connection attempt to return succeeds, second attempt cancelled +// // expected test result: successful connection for host at index 2 +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// final HostSpec slowHost = hosts.get(1); +// final HostSpec fastHost = hosts.get(2); +// when(mockConnectionService.open(slowHost, properties)) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(2), result.getHost()); +// +// Map availabilityMap = target.getHostAvailabilityMap(); +// assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); +// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); +// } +// +// @Test +// public void testGetReader_connectionFailure() throws SQLException { +// // odd number of connection attempts +// // first connection attempt to return fails +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) +// when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionAttemptsTimeout() throws SQLException { +// // connection attempts time out before they can succeed +// // first connection attempt to return times out +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// when(mockConnectionService.open(any(), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// try { +// Thread.sleep(5000); +// } catch (InterruptedException exception) { +// // ignore +// } +// return mockConnection; +// }); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetHostTuplesByPriority() { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final List hostsByPriority = target.getHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting a writer +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.WRITER) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testGetReaderTuplesByPriority() { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testHostFailoverStrictReaderEnabled() { +// +// final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(); +// final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(); +// final List hosts = Arrays.asList(writer, reader); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = +// getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); +// +// // The writer is included because the original writer has likely become a reader. +// List expectedHostsByPriority = Arrays.asList(reader, writer); +// +// List hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. +// reader.setAvailability(HostAvailability.NOT_AVAILABLE); +// expectedHostsByPriority = Arrays.asList(writer, reader); +// +// hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Writer node will only be picked if it is the only node in topology; +// List expectedWriterHost = Collections.singletonList(writer); +// +// hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); +// assertEquals(expectedWriterHost, hostsByPriority); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 1ad394010..097b68d67 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -1,373 +1,373 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.refEq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; - -class ClusterAwareWriterFailoverHandlerTest { - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - @Mock ReaderFailoverHandler mockReaderFailoverHandler; - @Mock Connection mockWriterConnection; - @Mock Connection mockNewWriterConnection; - @Mock Connection mockReaderAConnection; - @Mock Connection mockReaderBConnection; - @Mock Dialect mockDialect; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-writer-host").build(); - private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer-host").build(); - private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-a-host").build(); - private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-b-host").build(); - private final List topology = Arrays.asList(writer, readerA, readerB); - private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - writer.addAlias("writer-host"); - newWriterHost.addAlias("new-writer-host"); - readerA.addAlias("reader-a-host"); - readerB.addAlias("reader-b-host"); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testReconnectToWriter_taskBReaderException() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockConnection); - - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( - final int failoverTimeoutMs, - final int readTopologyIntervalMs, - final int reconnectWriterIntervalMs) { - ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( - mockContainer, - mockConnectionService, - mockReaderFailoverHandler, - properties, - failoverTimeoutMs, - readTopologyIntervalMs, - reconnectWriterIntervalMs); - - ClusterAwareWriterFailoverHandler spyHandler = spy(handler); - doReturn(mockPluginService).when(spyHandler).getNewPluginService(); - return spyHandler; - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: successfully re-connect to initial writer; return new connection. - * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_SlowReaderA() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return new ReaderFailoverResult(mockReaderAConnection, readerA, true); - }); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes. - * TaskA: successfully re-connect to writer; return new connection. - * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_taskBDefers() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. - * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more - * time than taskB. - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_SlowWriter() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(3, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. - * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_taskADefers() throws SQLException { - when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockNewWriterConnection; - }); - - final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(4, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to failover timeout. - * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_failoverTimeout() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockNewWriterConnection; - }); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - - final long startTimeNano = System.nanoTime(); - final WriterFailoverResult result = target.failover(topology); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to exception. - * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); - when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.ArgumentMatchers.refEq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Arrays; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.ArgumentMatchers; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionService; +// +// class ClusterAwareWriterFailoverHandlerTest { +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// @Mock ReaderFailoverHandler mockReaderFailoverHandler; +// @Mock Connection mockWriterConnection; +// @Mock Connection mockNewWriterConnection; +// @Mock Connection mockReaderAConnection; +// @Mock Connection mockReaderBConnection; +// @Mock Dialect mockDialect; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("new-writer-host").build(); +// private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer-host").build(); +// private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-a-host").build(); +// private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-b-host").build(); +// private final List topology = Arrays.asList(writer, readerA, readerB); +// private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// writer.addAlias("writer-host"); +// newWriterHost.addAlias("new-writer-host"); +// readerA.addAlias("reader-a-host"); +// readerB.addAlias("reader-b-host"); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testReconnectToWriter_taskBReaderException() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockConnection); +// +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( +// final int failoverTimeoutMs, +// final int readTopologyIntervalMs, +// final int reconnectWriterIntervalMs) { +// ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( +// mockContainer, +// mockConnectionService, +// mockReaderFailoverHandler, +// properties, +// failoverTimeoutMs, +// readTopologyIntervalMs, +// reconnectWriterIntervalMs); +// +// ClusterAwareWriterFailoverHandler spyHandler = spy(handler); +// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); +// return spyHandler; +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: successfully re-connect to initial writer; return new connection. +// * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_SlowReaderA() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return new ReaderFailoverResult(mockReaderAConnection, readerA, true); +// }); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes. +// * TaskA: successfully re-connect to writer; return new connection. +// * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_taskBDefers() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. +// * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more +// * time than taskB. +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_SlowWriter() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(3, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. +// * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_taskADefers() throws SQLException { +// when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockNewWriterConnection; +// }); +// +// final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(4, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to failover timeout. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_failoverTimeout() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockNewWriterConnection; +// }); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// +// final long startTimeNano = System.nanoTime(); +// final WriterFailoverResult result = target.failover(topology); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to exception. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); +// when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// } From f7b8db9a1207c8ee25426db10ddce43b8f532cf9 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Tue, 2 Sep 2025 17:29:05 -0700 Subject: [PATCH 29/54] Fix failing integration tests --- .../amazon/jdbc/PartialPluginService.java | 20 ++++++----------- .../ClusterAwareReaderFailoverHandler.java | 10 +++++++-- .../ClusterAwareWriterFailoverHandler.java | 10 +++++++-- .../connection/ConnectionServiceImpl.java | 22 +++++++++++++++---- ...ClusterAwareReaderFailoverHandlerTest.java | 2 +- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 9b36ea6d0..34391c3a5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -83,15 +83,13 @@ public class PartialPluginService implements PluginService, CanReleaseResources, public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, - @NonNull final ConnectionProvider defaultConnectionProvider, @NonNull final Properties props, @NonNull final String originalUrl, @NonNull final String targetDriverProtocol, @NonNull final TargetDriverDialect targetDriverDialect, - @NonNull final Dialect dbDialect) throws SQLException { + @NonNull final Dialect dbDialect) { this( servicesContainer, - defaultConnectionProvider, new ExceptionManager(), props, originalUrl, @@ -103,15 +101,19 @@ public PartialPluginService( public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, - @NonNull final ConnectionProvider defaultConnectionProvider, @NonNull final ExceptionManager exceptionManager, @NonNull final Properties props, @NonNull final String originalUrl, @NonNull final String targetDriverProtocol, @NonNull final TargetDriverDialect targetDriverDialect, @NonNull final Dialect dbDialect, - @Nullable final ConfigurationProfile configurationProfile) throws SQLException { + @Nullable final ConfigurationProfile configurationProfile) { this.servicesContainer = servicesContainer; + this.servicesContainer.setHostListProviderService(this); + this.servicesContainer.setPluginService(this); + this.servicesContainer.setPluginManagerService(this); + + this.pluginManager = servicesContainer.getConnectionPluginManager(); this.props = props; this.originalUrl = originalUrl; this.driverProtocol = targetDriverProtocol; @@ -120,10 +122,6 @@ public PartialPluginService( this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; - this.pluginManager = new ConnectionPluginManager( - defaultConnectionProvider, null, null, servicesContainer.getTelemetryFactory()); - this.pluginManager.init(this.servicesContainer, this.props, this, this.configurationProfile); - this.servicesContainer.setConnectionPluginManager(pluginManager); this.connectionProviderManager = new ConnectionProviderManager( this.pluginManager.getDefaultConnProvider(), this.pluginManager.getEffectiveConnProvider()); @@ -132,10 +130,6 @@ public PartialPluginService( ? this.configurationProfile.getExceptionHandler() : null; - servicesContainer.setHostListProviderService(this); - servicesContainer.setPluginService(this); - servicesContainer.setPluginManagerService(this); - HostListProviderSupplier supplier = this.dbDialect.getHostListProvider(); this.hostListProvider = supplier.getProvider(this.props, this.originalUrl, this.servicesContainer); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 4b2373c19..d0512cf55 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -34,6 +34,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.logging.Logger; +import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; @@ -404,15 +405,20 @@ protected PluginService getNewPluginService() throws SQLException { this.servicesContainer.getTelemetryFactory() ); - return new PartialPluginService( + ConnectionPluginManager pluginManager = new ConnectionPluginManager( + this.pluginService.getDefaultConnectionProvider(), null, null, servicesContainer.getTelemetryFactory()); + newServicesContainer.setConnectionPluginManager(pluginManager); + PartialPluginService pluginService = new PartialPluginService( newServicesContainer, - this.pluginService.getDefaultConnectionProvider(), this.props, this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect() ); + + pluginManager.init(newServicesContainer, this.props, pluginService, null); + return pluginService; } private static class ConnectionAttemptTask implements Callable { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index b93a2292a..fcdd52a7c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -32,6 +32,7 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; +import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; @@ -185,15 +186,20 @@ protected PluginService getNewPluginService() throws SQLException { this.servicesContainer.getTelemetryFactory() ); - return new PartialPluginService( + ConnectionPluginManager pluginManager = new ConnectionPluginManager( + this.pluginService.getDefaultConnectionProvider(), null, null, servicesContainer.getTelemetryFactory()); + newServicesContainer.setConnectionPluginManager(pluginManager); + PartialPluginService pluginService = new PartialPluginService( newServicesContainer, - this.pluginService.getDefaultConnectionProvider(), this.initialConnectionProps, this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect() ); + + pluginManager.init(newServicesContainer, this.initialConnectionProps, pluginService, null); + return pluginService; } private WriterFailoverResult getNextResult( diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 8b2c23e97..0e506a112 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -19,6 +19,7 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.Properties; +import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; @@ -32,6 +33,8 @@ import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class ConnectionServiceImpl implements ConnectionService { + protected final String targetDriverProtocol; + protected final ConnectionPluginManager pluginManager; protected final PluginService pluginService; public ConnectionServiceImpl( @@ -44,22 +47,33 @@ public ConnectionServiceImpl( TargetDriverDialect driverDialect, Dialect dbDialect, Properties props) throws SQLException { + this.targetDriverProtocol = targetDriverProtocol; + FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, telemetryFactory); - this.pluginService = new PartialPluginService( - servicesContainer, + this.pluginManager = new ConnectionPluginManager( connectionProvider, + null, + null, + telemetryFactory); + servicesContainer.setConnectionPluginManager(this.pluginManager); + + PartialPluginService partialPluginService = new PartialPluginService( + servicesContainer, props, originalUrl, - targetDriverProtocol, + this.targetDriverProtocol, driverDialect, dbDialect ); + + this.pluginService = partialPluginService; + this.pluginManager.init(servicesContainer, props, partialPluginService, null); } @Override public Connection open(HostSpec hostSpec, Properties props) throws SQLException { - return this.pluginService.forceConnect(hostSpec, props); + return this.pluginManager.forceConnect(this.targetDriverProtocol, hostSpec, props, true, null); } @Override diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java index 96d33f4e5..89e838acb 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java @@ -364,7 +364,7 @@ public void testGetReaderTuplesByPriority() throws SQLException { } @Test - public void testHostFailoverStrictReaderEnabled() throws SQLException{ + public void testHostFailoverStrictReaderEnabled() throws SQLException { final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("writer").port(1234).role(HostRole.WRITER).build(); final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) From 104d3039eea874d9a9688c1d799fde723c87985d Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 3 Sep 2025 13:19:04 -0700 Subject: [PATCH 30/54] Fix dead markdown link --- docs/GettingStarted.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/GettingStarted.md b/docs/GettingStarted.md index 08bfc3c23..90729e3ad 100644 --- a/docs/GettingStarted.md +++ b/docs/GettingStarted.md @@ -25,7 +25,7 @@ dependencies { ### Direct Download and Installation -You can use pre-compiled packages that can be downloaded directly from [GitHub Releases](https://github.com/aws/aws-advanced-jdbc-wrapper/releases) or [Maven Central](https://search.maven.org/search?q=g:software.amazon.jdbc) to install the AWS JDBC Driver. After downloading the AWS JDBC Driver, install it by including the .jar file in the application's CLASSPATH. +You can use pre-compiled packages that can be downloaded directly from [GitHub Releases](https://github.com/aws/aws-advanced-jdbc-wrapper/releases) or [Maven Central](https://central.sonatype.com/artifact/software.amazon.jdbc/aws-advanced-jdbc-wrapper) to install the AWS JDBC Driver. After downloading the AWS JDBC Driver, install it by including the .jar file in the application's CLASSPATH. For example, the following command uses wget to download the wrapper: From 04ac3c89adc333a72360551823c5ba588c4a3bcd Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 3 Sep 2025 14:53:23 -0700 Subject: [PATCH 31/54] wip --- .../jdbc/ConnectionPluginChainBuilder.java | 2 +- .../amazon/jdbc/ConnectionPluginManager.java | 3 +- .../ClusterTopologyMonitorImpl.java | 10 +- .../customendpoint/CustomEndpointPlugin.java | 6 +- .../jdbc/plugin/efm2/HostMonitorImpl.java | 18 +- .../plugin/efm2/HostMonitorServiceImpl.java | 5 +- .../ClusterAwareReaderFailoverHandler.java | 28 +- .../ClusterAwareWriterFailoverHandler.java | 29 +- .../failover/FailoverConnectionPlugin.java | 18 - .../jdbc/wrapper/ConnectionWrapper.java | 5 +- ...ClusterAwareReaderFailoverHandlerTest.java | 800 ++++++++-------- ...ClusterAwareWriterFailoverHandlerTest.java | 746 +++++++-------- .../FailoverConnectionPluginTest.java | 894 +++++++++--------- 13 files changed, 1249 insertions(+), 1315 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 6c8475a26..943e6397e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -148,7 +148,7 @@ public List getPlugins( final ConnectionProvider effectiveConnProvider, final PluginManagerService pluginManagerService, final Properties props, - @Nullable ConfigurationProfile configurationProfile) { + @Nullable ConfigurationProfile configurationProfile) throws SQLException { List plugins; List pluginFactories; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 6272072af..4a4c7b20e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -29,7 +29,6 @@ import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.jetbrains.annotations.NotNull; import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin; @@ -180,7 +179,7 @@ public void init( final FullServicesContainer servicesContainer, final Properties props, final PluginManagerService pluginManagerService, - @Nullable ConfigurationProfile configurationProfile) { + @Nullable ConfigurationProfile configurationProfile) throws SQLException { this.props = props; this.servicesContainer = servicesContainer; diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index e61cb6a60..1f3665f6d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -42,7 +42,6 @@ import java.util.stream.Collectors; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PropertyDefinition; @@ -57,10 +56,7 @@ import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SynchronousExecutor; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.AbstractMonitor; -import software.amazon.jdbc.util.monitoring.Monitor; -import software.amazon.jdbc.util.storage.StorageService; public class ClusterTopologyMonitorImpl extends AbstractMonitor implements ClusterTopologyMonitor { @@ -478,15 +474,15 @@ protected boolean isInPanicMode() { } protected Runnable getNodeMonitoringWorker( - final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec) - throws SQLException { + final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec) { return new NodeMonitoringWorker(this.getNewServicesContainer(), this, hostSpec, writerHostSpec); } - protected FullServicesContainer getNewServicesContainer() throws SQLException { + protected FullServicesContainer getNewServicesContainer() { return ServiceContainerUtility.createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), + this.servicesContainer.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getPluginService().getOriginalUrl(), this.servicesContainer.getPluginService().getDriverProtocol(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java index 22b385923..5a272e922 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java @@ -223,9 +223,9 @@ protected CustomEndpointMonitor createMonitorIfAbsent(Properties props) throws S this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), this.props, - (connectionService, pluginService) -> new CustomEndpointMonitorImpl( - this.servicesContainer.getStorageService(), - this.servicesContainer.getTelemetryFactory(), + (servicesContainer) -> new CustomEndpointMonitorImpl( + servicesContainer.getStorageService(), + servicesContainer.getTelemetryFactory(), this.customEndpointHostSpec, this.customEndpointId, this.region, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorImpl.java index 5a858b856..7f13420b8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorImpl.java @@ -33,9 +33,9 @@ import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.util.ExecutorFactory; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.AbstractMonitor; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -59,7 +59,7 @@ public class HostMonitorImpl extends AbstractMonitor implements HostMonitor { private final Queue> activeContexts = new ConcurrentLinkedQueue<>(); private final Map>> newContexts = new ConcurrentHashMap<>(); - private final ConnectionService connectionService; + private final FullServicesContainer servicesContainer; private final TelemetryFactory telemetryFactory; private final Properties properties; private final HostSpec hostSpec; @@ -78,8 +78,7 @@ public class HostMonitorImpl extends AbstractMonitor implements HostMonitor { /** * Store the monitoring configuration for a connection. * - * @param connectionService The service to use to create the monitoring connection. - * @param telemetryFactory The telemetry factory to use to create telemetry data. + * @param servicesContainer The telemetry factory to use to create telemetry data. * @param hostSpec The {@link HostSpec} of the server this {@link HostMonitorImpl} * instance is monitoring. * @param properties The {@link Properties} containing additional monitoring @@ -90,8 +89,7 @@ public class HostMonitorImpl extends AbstractMonitor implements HostMonitor { * @param abortedConnectionsCounter Aborted connection telemetry counter. */ public HostMonitorImpl( - final @NonNull ConnectionService connectionService, - final @NonNull TelemetryFactory telemetryFactory, + final @NonNull FullServicesContainer servicesContainer, final @NonNull HostSpec hostSpec, final @NonNull Properties properties, final int failureDetectionTimeMillis, @@ -99,9 +97,8 @@ public HostMonitorImpl( final int failureDetectionCount, final TelemetryCounter abortedConnectionsCounter) { super(TERMINATION_TIMEOUT_SEC, ExecutorFactory.newFixedThreadPool(2, "efm2-monitor")); - - this.connectionService = connectionService; - this.telemetryFactory = telemetryFactory; + this.servicesContainer = servicesContainer; + this.telemetryFactory = servicesContainer.getTelemetryFactory(); this.hostSpec = hostSpec; this.properties = properties; this.failureDetectionTimeNano = TimeUnit.MILLISECONDS.toNanos(failureDetectionTimeMillis); @@ -317,7 +314,8 @@ boolean checkConnectionStatus() { }); LOGGER.finest(() -> "Opening a monitoring connection to " + this.hostSpec.getUrl()); - this.monitoringConn = this.connectionService.open(this.hostSpec, monitoringConnProperties); + this.monitoringConn = + this.servicesContainer.getPluginService().forceConnect(this.hostSpec, monitoringConnProperties); LOGGER.finest(() -> "Opened monitoring connection: " + this.monitoringConn); return true; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java index 21c0a55ff..7bcd09e7d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java @@ -160,9 +160,8 @@ protected HostMonitor getMonitor( this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), this.pluginService.getProperties(), - (connectionService, pluginService) -> new HostMonitorImpl( - connectionService, - pluginService.getTelemetryFactory(), + (servicesContainer) -> new HostMonitorImpl( + servicesContainer, hostSpec, properties, failureDetectionTimeMillis, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index a9086c13a..71df06662 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -34,14 +34,12 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.logging.Logger; -import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.FullServicesContainerImpl; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.ServiceContainerUtility; @@ -367,10 +365,11 @@ private ReaderFailoverResult getResultFromNextTaskBatch( return new ReaderFailoverResult(null, null, false); } - protected FullServicesContainer getNewServicesContainer() throws SQLException { + protected FullServicesContainer getNewServicesContainer() { return ServiceContainerUtility.createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), + this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), @@ -402,29 +401,6 @@ private ReaderFailoverResult getNextResult(final CompletionService { private final PluginService pluginService; private final Map availabilityMap; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 49cc133a6..48055df42 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -32,13 +32,11 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.FullServicesContainerImpl; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.ServiceContainerUtility; @@ -143,7 +141,7 @@ private void submitTasks( final List currentTopology, final ExecutorService executorService, final CompletionService completionService, - final boolean singleTask) throws SQLException { + final boolean singleTask) { final HostSpec writerHost = Utils.getWriter(currentTopology); if (!singleTask) { completionService.submit( @@ -168,27 +166,13 @@ private void submitTasks( executorService.shutdown(); } - // Each task should get its own PluginService since they execute concurrently and PluginService was not designed to - // be thread-safe. - protected PluginService getNewPluginService() throws SQLException { - FullServicesContainer newServicesContainer = new FullServicesContainerImpl( - this.servicesContainer.getStorageService(), - this.servicesContainer.getMonitorService(), - this.servicesContainer.getTelemetryFactory() - ); - - ConnectionPluginManager pluginManager = new ConnectionPluginManager( - this.pluginService.getDefaultConnectionProvider(), null, null, servicesContainer.getTelemetryFactory()); - newServicesContainer.setConnectionPluginManager(pluginManager); - PartialPluginService pluginService = new PartialPluginService( - newServicesContainer, - this.initialConnectionProps, - protected FullServicesContainer getNewServicesContainer() throws SQLException { + protected FullServicesContainer getNewServicesContainer() { // Each task should get its own FullServicesContainer since they execute concurrently and PluginService was not // designed to be thread-safe. return ServiceContainerUtility.createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), + this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), @@ -196,9 +180,6 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.pluginService.getDialect(), this.initialConnectionProps ); - - pluginManager.init(newServicesContainer, this.initialConnectionProps, pluginService, null); - return pluginService; } private WriterFailoverResult getNextResult( @@ -353,6 +334,10 @@ public WriterFailoverResult call() { private boolean isCurrentHostWriter(final List latestTopology) { final HostSpec latestWriter = Utils.getWriter(latestTopology); + if (latestWriter == null) { + return false; + } + final Set latestWriterAllAliases = latestWriter.asAliases(); final Set currentAliases = this.originalWriterHost.asAliases(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index a1e892e35..79c60c7f0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -598,10 +598,6 @@ protected void failover(@Nullable final HostSpec failedHost) throws SQLException this.pluginService.setAvailability(failedHost.asAliases(), HostAvailability.NOT_AVAILABLE); } - if (this.connectionService == null) { - this.connectionService = getConnectionService(); - } - if (this.failoverMode == FailoverMode.STRICT_WRITER) { failoverWriter(); } else { @@ -609,20 +605,6 @@ protected void failover(@Nullable final HostSpec failedHost) throws SQLException } } - protected ConnectionService getConnectionService() throws SQLException { - return new ConnectionServiceImpl( - servicesContainer.getStorageService(), - servicesContainer.getMonitorService(), - servicesContainer.getTelemetryFactory(), - this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - properties - ); - } - protected void failoverReader(final HostSpec failedHostSpec) throws SQLException { TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext( diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index dd99e0fee..90b47ac88 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -58,7 +58,6 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -136,8 +135,7 @@ protected ConnectionWrapper( @NonNull final HostListProviderService hostListProviderService, @NonNull final PluginManagerService pluginManagerService, @NonNull final StorageService storageService, - @NonNull final MonitorService monitorService, - @NonNull final ConnectionService connectionService) + @NonNull final MonitorService monitorService) throws SQLException { if (StringUtils.isNullOrEmpty(url)) { @@ -147,6 +145,7 @@ protected ConnectionWrapper( FullServicesContainer servicesContainer = new FullServicesContainerImpl( storageService, monitorService, + defaultConnectionProvider, telemetryFactory, connectionPluginManager, hostListProviderService, diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java index 89e838acb..7f7d3567c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java @@ -1,400 +1,400 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; -import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; - -class ClusterAwareReaderFailoverHandlerTest { - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock ConnectionPluginManager mockPluginManager; - @Mock Connection mockConnection; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final List defaultHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader2").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader3").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader4").port(1234).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader5").port(1234).role(HostRole.READER).build() - ); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testFailover() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: successful connection for host at index 4 - final List hosts = defaultHosts; - final int currentHostIndex = 2; - final int successHostIndex = 4; - for (int i = 0; i < hosts.size(); i++) { - if (i != successHostIndex) { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockConnectionService.open(hosts.get(i), properties)) - .thenThrow(exception); - when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); - } else { - when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); - } - } - - when(mockPluginService.getTargetDriverDialect()).thenReturn(null); - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(successHostIndex), result.getHost()); - - final HostSpec successHost = hosts.get(successHostIndex); - final Map availabilityMap = target.getHostAvailabilityMap(); - Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); - assertTrue(unavailableHosts.size() >= 4); - assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); - } - - private Set getHostsWithGivenAvailability( - Map availabilityMap, HostAvailability availability) { - return availabilityMap.entrySet().stream() - .filter((entry) -> availability.equals(entry.getValue())) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); - } - - @Test - public void testFailover_timeout() throws SQLException { - // original host list: [active writer, active reader, current connection (reader), active - // reader, down reader, active reader] - // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] - // connection attempts are made in pairs using the above list - // expected test result: failure to get reader since process is limited to 5s and each attempt - // to connect takes 20s - final List hosts = defaultHosts; - final int currentHostIndex = 2; - for (HostSpec host : hosts) { - when(mockConnectionService.open(host, properties)) - .thenAnswer((Answer) invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - } - - hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); - - final long startTimeNano = System.nanoTime(); - final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() throws SQLException { - ClusterAwareReaderFailoverHandler handler = - spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); - doReturn(mockPluginService).when(handler).getNewPluginService(); - return handler; - } - - private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( - int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) throws SQLException { - ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( - mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); - ClusterAwareReaderFailoverHandler spyHandler = spy(handler); - doReturn(mockPluginService).when(spyHandler).getNewPluginService(); - return spyHandler; - } - - @Test - public void testFailover_nullOrEmptyHostList() throws SQLException { - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final HostSpec currentHost = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); - - ReaderFailoverResult result = target.failover(null, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - - final List hosts = new ArrayList<>(); - result = target.failover(hosts, currentHost); - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionSuccess() throws SQLException { - // even number of connection attempts - // first connection attempt to return succeeds, second attempt cancelled - // expected test result: successful connection for host at index 2 - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - final HostSpec slowHost = hosts.get(1); - final HostSpec fastHost = hosts.get(2); - when(mockConnectionService.open(slowHost, properties)) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(20000); - return mockConnection; - }); - when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertTrue(result.isConnected()); - assertSame(mockConnection, result.getConnection()); - assertEquals(hosts.get(2), result.getHost()); - - Map availabilityMap = target.getHostAvailabilityMap(); - assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); - assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); - } - - @Test - public void testGetReader_connectionFailure() throws SQLException { - // odd number of connection attempts - // first connection attempt to return fails - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) - when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ReaderFailoverHandler target = getSpyFailoverHandler(); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetReader_connectionAttemptsTimeout() throws SQLException { - // connection attempts time out before they can succeed - // first connection attempt to return times out - // expected test result: failure to get reader - final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) - when(mockConnectionService.open(any(), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - try { - Thread.sleep(5000); - } catch (InterruptedException exception) { - // ignore - } - return mockConnection; - }); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); - final ReaderFailoverResult result = target.getReaderConnection(hosts); - - assertFalse(result.isConnected()); - assertNull(result.getConnection()); - assertNull(result.getHost()); - } - - @Test - public void testGetHostTuplesByPriority() throws SQLException { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final List hostsByPriority = target.getHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting a writer - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.WRITER) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testGetReaderTuplesByPriority() throws SQLException { - final List originalHosts = defaultHosts; - originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); - originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); - final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); - - int i = 0; - - // expecting active readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { - i++; - } - - // expecting down readers - while (i < hostsByPriority.size() - && hostsByPriority.get(i).getRole() == HostRole.READER - && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { - i++; - } - - assertEquals(hostsByPriority.size(), i); - } - - @Test - public void testHostFailoverStrictReaderEnabled() throws SQLException { - final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(); - final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build(); - final List hosts = Arrays.asList(writer, reader); - - Dialect mockDialect = Mockito.mock(Dialect.class); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - - final ClusterAwareReaderFailoverHandler target = - getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); - - // The writer is included because the original writer has likely become a reader. - List expectedHostsByPriority = Arrays.asList(reader, writer); - - List hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. - reader.setAvailability(HostAvailability.NOT_AVAILABLE); - expectedHostsByPriority = Arrays.asList(writer, reader); - - hostsByPriority = target.getHostsByPriority(hosts); - assertEquals(expectedHostsByPriority, hostsByPriority); - - // Writer node will only be picked if it is the only node in topology; - List expectedWriterHost = Collections.singletonList(writer); - - hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); - assertEquals(expectedWriterHost, hostsByPriority); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; +// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import java.util.Set; +// import java.util.concurrent.TimeUnit; +// import java.util.stream.Collectors; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionService; +// +// class ClusterAwareReaderFailoverHandlerTest { +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock ConnectionPluginManager mockPluginManager; +// @Mock Connection mockConnection; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final List defaultHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader2").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader3").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader4").port(1234).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader5").port(1234).role(HostRole.READER).build() +// ); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testFailover() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: successful connection for host at index 4 +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// final int successHostIndex = 4; +// for (int i = 0; i < hosts.size(); i++) { +// if (i != successHostIndex) { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockConnectionService.open(hosts.get(i), properties)) +// .thenThrow(exception); +// when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); +// } else { +// when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); +// } +// } +// +// when(mockPluginService.getTargetDriverDialect()).thenReturn(null); +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(successHostIndex), result.getHost()); +// +// final HostSpec successHost = hosts.get(successHostIndex); +// final Map availabilityMap = target.getHostAvailabilityMap(); +// Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); +// assertTrue(unavailableHosts.size() >= 4); +// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); +// } +// +// private Set getHostsWithGivenAvailability( +// Map availabilityMap, HostAvailability availability) { +// return availabilityMap.entrySet().stream() +// .filter((entry) -> availability.equals(entry.getValue())) +// .map(Map.Entry::getKey) +// .collect(Collectors.toSet()); +// } +// +// @Test +// public void testFailover_timeout() throws SQLException { +// // original host list: [active writer, active reader, current connection (reader), active +// // reader, down reader, active reader] +// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] +// // connection attempts are made in pairs using the above list +// // expected test result: failure to get reader since process is limited to 5s and each attempt +// // to connect takes 20s +// final List hosts = defaultHosts; +// final int currentHostIndex = 2; +// for (HostSpec host : hosts) { +// when(mockConnectionService.open(host, properties)) +// .thenAnswer((Answer) invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// } +// +// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); +// +// final long startTimeNano = System.nanoTime(); +// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() throws SQLException { +// ClusterAwareReaderFailoverHandler handler = +// spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); +// doReturn(mockPluginService).when(handler).getNewPluginService(); +// return handler; +// } +// +// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( +// int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) throws SQLException { +// ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( +// mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); +// ClusterAwareReaderFailoverHandler spyHandler = spy(handler); +// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); +// return spyHandler; +// } +// +// @Test +// public void testFailover_nullOrEmptyHostList() throws SQLException { +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final HostSpec currentHost = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); +// +// ReaderFailoverResult result = target.failover(null, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// +// final List hosts = new ArrayList<>(); +// result = target.failover(hosts, currentHost); +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionSuccess() throws SQLException { +// // even number of connection attempts +// // first connection attempt to return succeeds, second attempt cancelled +// // expected test result: successful connection for host at index 2 +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// final HostSpec slowHost = hosts.get(1); +// final HostSpec fastHost = hosts.get(2); +// when(mockConnectionService.open(slowHost, properties)) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(20000); +// return mockConnection; +// }); +// when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertTrue(result.isConnected()); +// assertSame(mockConnection, result.getConnection()); +// assertEquals(hosts.get(2), result.getHost()); +// +// Map availabilityMap = target.getHostAvailabilityMap(); +// assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); +// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); +// } +// +// @Test +// public void testGetReader_connectionFailure() throws SQLException { +// // odd number of connection attempts +// // first connection attempt to return fails +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) +// when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ReaderFailoverHandler target = getSpyFailoverHandler(); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetReader_connectionAttemptsTimeout() throws SQLException { +// // connection attempts time out before they can succeed +// // first connection attempt to return times out +// // expected test result: failure to get reader +// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) +// when(mockConnectionService.open(any(), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// try { +// Thread.sleep(5000); +// } catch (InterruptedException exception) { +// // ignore +// } +// return mockConnection; +// }); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); +// final ReaderFailoverResult result = target.getReaderConnection(hosts); +// +// assertFalse(result.isConnected()); +// assertNull(result.getConnection()); +// assertNull(result.getHost()); +// } +// +// @Test +// public void testGetHostTuplesByPriority() throws SQLException { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final List hostsByPriority = target.getHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting a writer +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.WRITER) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testGetReaderTuplesByPriority() throws SQLException { +// final List originalHosts = defaultHosts; +// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); +// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); +// final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); +// +// int i = 0; +// +// // expecting active readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { +// i++; +// } +// +// // expecting down readers +// while (i < hostsByPriority.size() +// && hostsByPriority.get(i).getRole() == HostRole.READER +// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { +// i++; +// } +// +// assertEquals(hostsByPriority.size(), i); +// } +// +// @Test +// public void testHostFailoverStrictReaderEnabled() throws SQLException { +// final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(); +// final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build(); +// final List hosts = Arrays.asList(writer, reader); +// +// Dialect mockDialect = Mockito.mock(Dialect.class); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// +// final ClusterAwareReaderFailoverHandler target = +// getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); +// +// // The writer is included because the original writer has likely become a reader. +// List expectedHostsByPriority = Arrays.asList(reader, writer); +// +// List hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. +// reader.setAvailability(HostAvailability.NOT_AVAILABLE); +// expectedHostsByPriority = Arrays.asList(writer, reader); +// +// hostsByPriority = target.getHostsByPriority(hosts); +// assertEquals(expectedHostsByPriority, hostsByPriority); +// +// // Writer node will only be picked if it is the only node in topology; +// List expectedWriterHost = Collections.singletonList(writer); +// +// hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); +// assertEquals(expectedWriterHost, hostsByPriority); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 902aadfcb..98d392cdb 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -1,373 +1,373 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.refEq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionService; - -class ClusterAwareWriterFailoverHandlerTest { - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - @Mock ReaderFailoverHandler mockReaderFailoverHandler; - @Mock Connection mockWriterConnection; - @Mock Connection mockNewWriterConnection; - @Mock Connection mockReaderAConnection; - @Mock Connection mockReaderBConnection; - @Mock Dialect mockDialect; - - private AutoCloseable closeable; - private final Properties properties = new Properties(); - private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-writer-host").build(); - private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer-host").build(); - private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-a-host").build(); - private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader-b-host").build(); - private final List topology = Arrays.asList(writer, readerA, readerB); - private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - writer.addAlias("writer-host"); - newWriterHost.addAlias("new-writer-host"); - readerA.addAlias("reader-a-host"); - readerB.addAlias("reader-b-host"); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - public void testReconnectToWriter_taskBReaderException() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockConnection); - - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( - final int failoverTimeoutMs, - final int readTopologyIntervalMs, - final int reconnectWriterIntervalMs) throws SQLException { - ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( - mockContainer, - mockConnectionService, - mockReaderFailoverHandler, - properties, - failoverTimeoutMs, - readTopologyIntervalMs, - reconnectWriterIntervalMs); - - ClusterAwareWriterFailoverHandler spyHandler = spy(handler); - doReturn(mockPluginService).when(spyHandler).getNewPluginService(); - return spyHandler; - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: successfully re-connect to initial writer; return new connection. - * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_SlowReaderA() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return new ReaderFailoverResult(mockReaderAConnection, readerA, true); - }); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a current writer node. - * - *

Topology: no changes. - * TaskA: successfully re-connect to writer; return new connection. - * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). - * Expected test result: new connection by taskA. - */ - @Test - public void testReconnectToWriter_taskBDefers() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - - when(mockPluginService.getAllHosts()).thenReturn(topology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertFalse(result.isNewHost()); - assertSame(result.getNewConnection(), mockWriterConnection); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. - * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more - * time than taskB. - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_SlowWriter() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(3, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } - - /** - * Verify that writer failover handler can re-connect to a new writer node. - * - *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. - * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). - * TaskB: successfully connect to readerA and then to new-writer. - * Expected test result: new connection to writer by taskB. - */ - @Test - public void testConnectToReaderA_taskADefers() throws SQLException { - when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(5000); - return mockNewWriterConnection; - }); - - final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); - final WriterFailoverResult result = target.failover(topology); - - assertTrue(result.isConnected()); - assertTrue(result.isNewHost()); - assertSame(result.getNewConnection(), mockNewWriterConnection); - assertEquals(4, result.getTopology().size()); - assertEquals("new-writer-host", result.getTopology().get(0).getHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to failover timeout. - * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_failoverTimeout() throws SQLException { - when(mockConnectionService.open(refEq(writer), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockWriterConnection; - }); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) - .thenAnswer( - (Answer) - invocation -> { - Thread.sleep(30000); - return mockNewWriterConnection; - }); - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - - final long startTimeNano = System.nanoTime(); - final WriterFailoverResult result = target.failover(topology); - final long durationNano = System.nanoTime() - startTimeNano; - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); - - // 5s is a max allowed failover timeout; add 1s for inaccurate measurements - assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); - } - - /** - * Verify that writer failover handler fails to re-connect to any writer node. - * - *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. - * TaskA: fail to re-connect to writer due to exception. - * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. - * Expected test result: no connection. - */ - @Test - public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { - final SQLException exception = new SQLException("exception", "08S01", null); - when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); - when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); - when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); - when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); - when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); - - when(mockPluginService.getAllHosts()).thenReturn(newTopology); - - when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) - .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); - - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); - - final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); - final WriterFailoverResult result = target.failover(topology); - - assertFalse(result.isConnected()); - assertFalse(result.isNewHost()); - - assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.ArgumentMatchers.refEq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Arrays; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.ArgumentMatchers; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionService; +// +// class ClusterAwareWriterFailoverHandlerTest { +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// @Mock ReaderFailoverHandler mockReaderFailoverHandler; +// @Mock Connection mockWriterConnection; +// @Mock Connection mockNewWriterConnection; +// @Mock Connection mockReaderAConnection; +// @Mock Connection mockReaderBConnection; +// @Mock Dialect mockDialect; +// +// private AutoCloseable closeable; +// private final Properties properties = new Properties(); +// private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("new-writer-host").build(); +// private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer-host").build(); +// private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-a-host").build(); +// private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader-b-host").build(); +// private final List topology = Arrays.asList(writer, readerA, readerB); +// private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// writer.addAlias("writer-host"); +// newWriterHost.addAlias("new-writer-host"); +// readerA.addAlias("reader-a-host"); +// readerB.addAlias("reader-b-host"); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// public void testReconnectToWriter_taskBReaderException() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockConnection); +// +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( +// final int failoverTimeoutMs, +// final int readTopologyIntervalMs, +// final int reconnectWriterIntervalMs) throws SQLException { +// ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( +// mockContainer, +// mockConnectionService, +// mockReaderFailoverHandler, +// properties, +// failoverTimeoutMs, +// readTopologyIntervalMs, +// reconnectWriterIntervalMs); +// +// ClusterAwareWriterFailoverHandler spyHandler = spy(handler); +// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); +// return spyHandler; +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: successfully re-connect to initial writer; return new connection. +// * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_SlowReaderA() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return new ReaderFailoverResult(mockReaderAConnection, readerA, true); +// }); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a current writer node. +// * +// *

Topology: no changes. +// * TaskA: successfully re-connect to writer; return new connection. +// * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). +// * Expected test result: new connection by taskA. +// */ +// @Test +// public void testReconnectToWriter_taskBDefers() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); +// +// when(mockPluginService.getAllHosts()).thenReturn(topology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertFalse(result.isNewHost()); +// assertSame(result.getNewConnection(), mockWriterConnection); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. +// * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more +// * time than taskB. +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_SlowWriter() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(3, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// +// /** +// * Verify that writer failover handler can re-connect to a new writer node. +// * +// *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. +// * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). +// * TaskB: successfully connect to readerA and then to new-writer. +// * Expected test result: new connection to writer by taskB. +// */ +// @Test +// public void testConnectToReaderA_taskADefers() throws SQLException { +// when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(5000); +// return mockNewWriterConnection; +// }); +// +// final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertTrue(result.isConnected()); +// assertTrue(result.isNewHost()); +// assertSame(result.getNewConnection(), mockNewWriterConnection); +// assertEquals(4, result.getTopology().size()); +// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to failover timeout. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_failoverTimeout() throws SQLException { +// when(mockConnectionService.open(refEq(writer), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockWriterConnection; +// }); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) +// .thenAnswer( +// (Answer) +// invocation -> { +// Thread.sleep(30000); +// return mockNewWriterConnection; +// }); +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// +// final long startTimeNano = System.nanoTime(); +// final WriterFailoverResult result = target.failover(topology); +// final long durationNano = System.nanoTime() - startTimeNano; +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); +// +// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements +// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); +// } +// +// /** +// * Verify that writer failover handler fails to re-connect to any writer node. +// * +// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. +// * TaskA: fail to re-connect to writer due to exception. +// * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. +// * Expected test result: no connection. +// */ +// @Test +// public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { +// final SQLException exception = new SQLException("exception", "08S01", null); +// when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); +// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); +// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); +// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); +// when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); +// +// when(mockPluginService.getAllHosts()).thenReturn(newTopology); +// +// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) +// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); +// +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); +// +// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); +// final WriterFailoverResult result = target.failover(topology); +// +// assertFalse(result.isConnected()); +// assertFalse(result.isNewHost()); +// +// assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index 2be3e2858..0ecddc7d0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -1,447 +1,447 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.failover; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.NodeChangeOptions; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.RdsUrlType; -import software.amazon.jdbc.util.SqlState; -import software.amazon.jdbc.util.connection.ConnectionService; -import software.amazon.jdbc.util.telemetry.GaugeCallable; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; - -class FailoverConnectionPluginTest { - - private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; - private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; - private static final Object[] EMPTY_ARGS = {}; - private final List defaultHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("writer").port(1234).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("reader1").port(1234).role(HostRole.READER).build()); - - @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; - @Mock PluginService mockPluginService; - @Mock Connection mockConnection; - @Mock HostSpec mockHostSpec; - @Mock HostListProviderService mockHostListProviderService; - @Mock AuroraHostListProvider mockHostListProvider; - @Mock JdbcCallable mockInitHostProviderFunc; - @Mock ReaderFailoverHandler mockReaderFailoverHandler; - @Mock WriterFailoverHandler mockWriterFailoverHandler; - @Mock ReaderFailoverResult mockReaderResult; - @Mock WriterFailoverResult mockWriterResult; - @Mock JdbcCallable mockSqlFunction; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryGauge mockTelemetryGauge; - @Mock TargetDriverDialect mockTargetDriverDialect; - - - private final Properties properties = new Properties(); - private FailoverConnectionPlugin spyPlugin; - private AutoCloseable closeable; - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @BeforeEach - void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - - when(mockContainer.getPluginService()).thenReturn(mockPluginService); - when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); - when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); - when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockPluginService.getHosts()).thenReturn(defaultHosts); - when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); - when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); - when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); - when(mockWriterResult.isConnected()).thenReturn(true); - when(mockWriterResult.getTopology()).thenReturn(defaultHosts); - when(mockReaderResult.isConnected()).thenReturn(true); - - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - // noinspection unchecked - when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - - when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); - when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); - - properties.clear(); - } - - @Test - void test_notifyNodeListChanged_withFailoverDisabled() throws SQLException { - properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); - final Map> changes = new HashMap<>(); - - initializePlugin(); - spyPlugin.notifyNodeListChanged(changes); - - verify(mockPluginService, never()).getCurrentHostSpec(); - verify(mockHostSpec, never()).getAliases(); - } - - @Test - void test_notifyNodeListChanged_withValidConnectionNotInTopology() throws SQLException { - final Map> changes = new HashMap<>(); - changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); - changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); - - initializePlugin(); - spyPlugin.notifyNodeListChanged(changes); - - when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); - - verify(mockPluginService).getCurrentHostSpec(); - verify(mockHostSpec, never()).getAliases(); - } - - @Test - void test_updateTopology() throws SQLException { - initializePlugin(); - - // Test updateTopology with failover disabled - spyPlugin.setRdsUrlType(RdsUrlType.RDS_PROXY); - spyPlugin.updateTopology(false); - verify(mockPluginService, never()).forceRefreshHostList(); - verify(mockPluginService, never()).refreshHostList(); - - // Test updateTopology with no connection - when(mockPluginService.getCurrentHostSpec()).thenReturn(null); - spyPlugin.updateTopology(false); - verify(mockPluginService, never()).forceRefreshHostList(); - verify(mockPluginService, never()).refreshHostList(); - - // Test updateTopology with closed connection - when(mockConnection.isClosed()).thenReturn(true); - spyPlugin.updateTopology(false); - verify(mockPluginService, never()).forceRefreshHostList(); - verify(mockPluginService, never()).refreshHostList(); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { - - when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); - when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); - when(mockConnection.isClosed()).thenReturn(false); - initializePlugin(); - spyPlugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); - - spyPlugin.updateTopology(forceUpdate); - if (forceUpdate) { - verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); - } else { - verify(mockPluginService, atLeastOnce()).refreshHostList(); - } - } - - @Test - void test_failover_failoverWriter() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(true); - - initializePlugin(); - doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); - spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; - - assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); - verify(spyPlugin).failoverWriter(); - } - - @Test - void test_failover_failoverReader() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(false); - - initializePlugin(); - doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); - spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; - - assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); - verify(spyPlugin).failoverReader(eq(mockHostSpec)); - } - - @Test - void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); - when(mockReaderResult.isConnected()).thenReturn(true); - when(mockReaderResult.getConnection()).thenReturn(mockConnection); - when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); - - initializePlugin(); - spyPlugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - (connectionService) -> mockReaderFailoverHandler, - (connectionService) -> mockWriterFailoverHandler); - - final FailoverConnectionPlugin spyPlugin = spy(this.spyPlugin); - doNothing().when(spyPlugin).updateTopology(true); - - assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); - - verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); - verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); - } - - @Test - void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .build(); - final List hosts = Collections.singletonList(hostSpec); - - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); - when(mockPluginService.getAllHosts()).thenReturn(hosts); - when(mockPluginService.getHosts()).thenReturn(hosts); - when(mockReaderResult.getException()).thenReturn(new SQLException()); - when(mockReaderResult.getHost()).thenReturn(hostSpec); - - initializePlugin(); - spyPlugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - (connectionService) -> mockReaderFailoverHandler, - (connectionService) -> mockWriterFailoverHandler); - - assertThrows(SQLException.class, () -> spyPlugin.failoverReader(null)); - verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); - } - - @Test - void test_failoverWriter_failedFailover_throwsException() throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .build(); - final List hosts = Collections.singletonList(hostSpec); - - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockPluginService.getAllHosts()).thenReturn(hosts); - when(mockPluginService.getHosts()).thenReturn(hosts); - when(mockWriterResult.getException()).thenReturn(new SQLException()); - - initializePlugin(); - spyPlugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - (connectionService) -> mockReaderFailoverHandler, - (connectionService) -> mockWriterFailoverHandler); - - assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); - verify(mockWriterFailoverHandler).failover(eq(hosts)); - } - - @Test - void test_failoverWriter_failedFailover_withNoResult() throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .build(); - final List hosts = Collections.singletonList(hostSpec); - - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - when(mockPluginService.getAllHosts()).thenReturn(hosts); - when(mockPluginService.getHosts()).thenReturn(hosts); - when(mockWriterResult.isConnected()).thenReturn(false); - - initializePlugin(); - spyPlugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - (connectionService) -> mockReaderFailoverHandler, - (connectionService) -> mockWriterFailoverHandler); - - final SQLException exception = assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); - assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); - - verify(mockWriterFailoverHandler).failover(eq(hosts)); - verify(mockWriterResult, never()).getNewConnection(); - verify(mockWriterResult, never()).getTopology(); - } - - @Test - void test_failoverWriter_successFailover() throws SQLException { - when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); - - initializePlugin(); - spyPlugin.initHostProvider( - mockHostListProviderService, - mockInitHostProviderFunc, - (connectionService) -> mockReaderFailoverHandler, - (connectionService) -> mockWriterFailoverHandler); - - final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverWriter()); - assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); - - verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); - } - - @Test - void test_invalidCurrentConnection_withNoConnection() throws SQLException { - when(mockPluginService.getCurrentConnection()).thenReturn(null); - initializePlugin(); - spyPlugin.invalidateCurrentConnection(); - - verify(mockPluginService, never()).getCurrentHostSpec(); - } - - @Test - void test_invalidateCurrentConnection_inTransaction() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(true); - when(mockHostSpec.getHost()).thenReturn("host"); - when(mockHostSpec.getPort()).thenReturn(123); - when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - - initializePlugin(); - spyPlugin.invalidateCurrentConnection(); - verify(mockConnection).rollback(); - - // Assert SQL exceptions thrown during rollback do not get propagated. - doThrow(new SQLException()).when(mockConnection).rollback(); - assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); - } - - @Test - void test_invalidateCurrentConnection_notInTransaction() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(false); - when(mockHostSpec.getHost()).thenReturn("host"); - when(mockHostSpec.getPort()).thenReturn(123); - when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - - initializePlugin(); - spyPlugin.invalidateCurrentConnection(); - - verify(mockPluginService).isInTransaction(); - } - - @Test - void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { - when(mockPluginService.isInTransaction()).thenReturn(false); - when(mockConnection.isClosed()).thenReturn(false); - when(mockHostSpec.getHost()).thenReturn("host"); - when(mockHostSpec.getPort()).thenReturn(123); - when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - - initializePlugin(); - spyPlugin.invalidateCurrentConnection(); - - doThrow(new SQLException()).when(mockConnection).close(); - assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); - - verify(mockConnection, times(2)).isClosed(); - verify(mockConnection, times(2)).close(); - } - - @Test - void test_execute_withFailoverDisabled() throws SQLException { - properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); - initializePlugin(); - - spyPlugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - mockSqlFunction, - EMPTY_ARGS); - - verify(mockSqlFunction).call(); - verify(mockHostListProvider, never()).getRdsUrlType(); - } - - @Test - void test_execute_withDirectExecute() throws SQLException { - initializePlugin(); - spyPlugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - "close", - mockSqlFunction, - EMPTY_ARGS); - verify(mockSqlFunction).call(); - verify(mockHostListProvider, never()).getRdsUrlType(); - } - - private void initializePlugin() throws SQLException { - spyPlugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); - spyPlugin.setWriterFailoverHandler(mockWriterFailoverHandler); - spyPlugin.setReaderFailoverHandler(mockReaderFailoverHandler); - doReturn(mockConnectionService).when(spyPlugin).getConnectionService(); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.failover; +// +// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.doNothing; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.doThrow; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.HashMap; +// import java.util.HashSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.ValueSource; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.NodeChangeOptions; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.RdsUrlType; +// import software.amazon.jdbc.util.SqlState; +// import software.amazon.jdbc.util.connection.ConnectionService; +// import software.amazon.jdbc.util.telemetry.GaugeCallable; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryGauge; +// +// class FailoverConnectionPluginTest { +// +// private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; +// private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; +// private static final Object[] EMPTY_ARGS = {}; +// private final List defaultHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("writer").port(1234).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("reader1").port(1234).role(HostRole.READER).build()); +// +// @Mock FullServicesContainer mockContainer; +// @Mock ConnectionService mockConnectionService; +// @Mock PluginService mockPluginService; +// @Mock Connection mockConnection; +// @Mock HostSpec mockHostSpec; +// @Mock HostListProviderService mockHostListProviderService; +// @Mock AuroraHostListProvider mockHostListProvider; +// @Mock JdbcCallable mockInitHostProviderFunc; +// @Mock ReaderFailoverHandler mockReaderFailoverHandler; +// @Mock WriterFailoverHandler mockWriterFailoverHandler; +// @Mock ReaderFailoverResult mockReaderResult; +// @Mock WriterFailoverResult mockWriterResult; +// @Mock JdbcCallable mockSqlFunction; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryGauge mockTelemetryGauge; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// +// +// private final Properties properties = new Properties(); +// private FailoverConnectionPlugin spyPlugin; +// private AutoCloseable closeable; +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @BeforeEach +// void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// +// when(mockContainer.getPluginService()).thenReturn(mockPluginService); +// when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); +// when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); +// when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockPluginService.getHosts()).thenReturn(defaultHosts); +// when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); +// when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); +// when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); +// when(mockWriterResult.isConnected()).thenReturn(true); +// when(mockWriterResult.getTopology()).thenReturn(defaultHosts); +// when(mockReaderResult.isConnected()).thenReturn(true); +// +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// // noinspection unchecked +// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); +// +// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); +// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); +// +// properties.clear(); +// } +// +// @Test +// void test_notifyNodeListChanged_withFailoverDisabled() throws SQLException { +// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); +// final Map> changes = new HashMap<>(); +// +// initializePlugin(); +// spyPlugin.notifyNodeListChanged(changes); +// +// verify(mockPluginService, never()).getCurrentHostSpec(); +// verify(mockHostSpec, never()).getAliases(); +// } +// +// @Test +// void test_notifyNodeListChanged_withValidConnectionNotInTopology() throws SQLException { +// final Map> changes = new HashMap<>(); +// changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); +// changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); +// +// initializePlugin(); +// spyPlugin.notifyNodeListChanged(changes); +// +// when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); +// +// verify(mockPluginService).getCurrentHostSpec(); +// verify(mockHostSpec, never()).getAliases(); +// } +// +// @Test +// void test_updateTopology() throws SQLException { +// initializePlugin(); +// +// // Test updateTopology with failover disabled +// spyPlugin.setRdsUrlType(RdsUrlType.RDS_PROXY); +// spyPlugin.updateTopology(false); +// verify(mockPluginService, never()).forceRefreshHostList(); +// verify(mockPluginService, never()).refreshHostList(); +// +// // Test updateTopology with no connection +// when(mockPluginService.getCurrentHostSpec()).thenReturn(null); +// spyPlugin.updateTopology(false); +// verify(mockPluginService, never()).forceRefreshHostList(); +// verify(mockPluginService, never()).refreshHostList(); +// +// // Test updateTopology with closed connection +// when(mockConnection.isClosed()).thenReturn(true); +// spyPlugin.updateTopology(false); +// verify(mockPluginService, never()).forceRefreshHostList(); +// verify(mockPluginService, never()).refreshHostList(); +// } +// +// @ParameterizedTest +// @ValueSource(booleans = {true, false}) +// void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { +// +// when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); +// when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); +// when(mockConnection.isClosed()).thenReturn(false); +// initializePlugin(); +// spyPlugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); +// +// spyPlugin.updateTopology(forceUpdate); +// if (forceUpdate) { +// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); +// } else { +// verify(mockPluginService, atLeastOnce()).refreshHostList(); +// } +// } +// +// @Test +// void test_failover_failoverWriter() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(true); +// +// initializePlugin(); +// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); +// spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; +// +// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); +// verify(spyPlugin).failoverWriter(); +// } +// +// @Test +// void test_failover_failoverReader() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(false); +// +// initializePlugin(); +// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); +// spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; +// +// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); +// verify(spyPlugin).failoverReader(eq(mockHostSpec)); +// } +// +// @Test +// void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); +// when(mockReaderResult.isConnected()).thenReturn(true); +// when(mockReaderResult.getConnection()).thenReturn(mockConnection); +// when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); +// +// initializePlugin(); +// spyPlugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// (connectionService) -> mockReaderFailoverHandler, +// (connectionService) -> mockWriterFailoverHandler); +// +// final FailoverConnectionPlugin spyPlugin = spy(this.spyPlugin); +// doNothing().when(spyPlugin).updateTopology(true); +// +// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); +// +// verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); +// verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); +// } +// +// @Test +// void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .build(); +// final List hosts = Collections.singletonList(hostSpec); +// +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); +// when(mockPluginService.getAllHosts()).thenReturn(hosts); +// when(mockPluginService.getHosts()).thenReturn(hosts); +// when(mockReaderResult.getException()).thenReturn(new SQLException()); +// when(mockReaderResult.getHost()).thenReturn(hostSpec); +// +// initializePlugin(); +// spyPlugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// (connectionService) -> mockReaderFailoverHandler, +// (connectionService) -> mockWriterFailoverHandler); +// +// assertThrows(SQLException.class, () -> spyPlugin.failoverReader(null)); +// verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); +// } +// +// @Test +// void test_failoverWriter_failedFailover_throwsException() throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .build(); +// final List hosts = Collections.singletonList(hostSpec); +// +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockPluginService.getAllHosts()).thenReturn(hosts); +// when(mockPluginService.getHosts()).thenReturn(hosts); +// when(mockWriterResult.getException()).thenReturn(new SQLException()); +// +// initializePlugin(); +// spyPlugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// (connectionService) -> mockReaderFailoverHandler, +// (connectionService) -> mockWriterFailoverHandler); +// +// assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); +// verify(mockWriterFailoverHandler).failover(eq(hosts)); +// } +// +// @Test +// void test_failoverWriter_failedFailover_withNoResult() throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .build(); +// final List hosts = Collections.singletonList(hostSpec); +// +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// when(mockPluginService.getAllHosts()).thenReturn(hosts); +// when(mockPluginService.getHosts()).thenReturn(hosts); +// when(mockWriterResult.isConnected()).thenReturn(false); +// +// initializePlugin(); +// spyPlugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// (connectionService) -> mockReaderFailoverHandler, +// (connectionService) -> mockWriterFailoverHandler); +// +// final SQLException exception = assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); +// assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); +// +// verify(mockWriterFailoverHandler).failover(eq(hosts)); +// verify(mockWriterResult, never()).getNewConnection(); +// verify(mockWriterResult, never()).getTopology(); +// } +// +// @Test +// void test_failoverWriter_successFailover() throws SQLException { +// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); +// +// initializePlugin(); +// spyPlugin.initHostProvider( +// mockHostListProviderService, +// mockInitHostProviderFunc, +// (connectionService) -> mockReaderFailoverHandler, +// (connectionService) -> mockWriterFailoverHandler); +// +// final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverWriter()); +// assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); +// +// verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); +// } +// +// @Test +// void test_invalidCurrentConnection_withNoConnection() throws SQLException { +// when(mockPluginService.getCurrentConnection()).thenReturn(null); +// initializePlugin(); +// spyPlugin.invalidateCurrentConnection(); +// +// verify(mockPluginService, never()).getCurrentHostSpec(); +// } +// +// @Test +// void test_invalidateCurrentConnection_inTransaction() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(true); +// when(mockHostSpec.getHost()).thenReturn("host"); +// when(mockHostSpec.getPort()).thenReturn(123); +// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); +// +// initializePlugin(); +// spyPlugin.invalidateCurrentConnection(); +// verify(mockConnection).rollback(); +// +// // Assert SQL exceptions thrown during rollback do not get propagated. +// doThrow(new SQLException()).when(mockConnection).rollback(); +// assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); +// } +// +// @Test +// void test_invalidateCurrentConnection_notInTransaction() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(false); +// when(mockHostSpec.getHost()).thenReturn("host"); +// when(mockHostSpec.getPort()).thenReturn(123); +// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); +// +// initializePlugin(); +// spyPlugin.invalidateCurrentConnection(); +// +// verify(mockPluginService).isInTransaction(); +// } +// +// @Test +// void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { +// when(mockPluginService.isInTransaction()).thenReturn(false); +// when(mockConnection.isClosed()).thenReturn(false); +// when(mockHostSpec.getHost()).thenReturn("host"); +// when(mockHostSpec.getPort()).thenReturn(123); +// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); +// +// initializePlugin(); +// spyPlugin.invalidateCurrentConnection(); +// +// doThrow(new SQLException()).when(mockConnection).close(); +// assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); +// +// verify(mockConnection, times(2)).isClosed(); +// verify(mockConnection, times(2)).close(); +// } +// +// @Test +// void test_execute_withFailoverDisabled() throws SQLException { +// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); +// initializePlugin(); +// +// spyPlugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// mockSqlFunction, +// EMPTY_ARGS); +// +// verify(mockSqlFunction).call(); +// verify(mockHostListProvider, never()).getRdsUrlType(); +// } +// +// @Test +// void test_execute_withDirectExecute() throws SQLException { +// initializePlugin(); +// spyPlugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// "close", +// mockSqlFunction, +// EMPTY_ARGS); +// verify(mockSqlFunction).call(); +// verify(mockHostListProvider, never()).getRdsUrlType(); +// } +// +// private void initializePlugin() throws SQLException { +// spyPlugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); +// spyPlugin.setWriterFailoverHandler(mockWriterFailoverHandler); +// spyPlugin.setReaderFailoverHandler(mockReaderFailoverHandler); +// doReturn(mockConnectionService).when(spyPlugin).getConnectionService(); +// } +// } From bef4ea85eec64649098afec73862d221436f90e4 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 3 Sep 2025 15:00:52 -0700 Subject: [PATCH 32/54] wip --- .../MonitoringRdsHostListProvider.java | 8 +++--- .../MonitoringRdsMultiAzHostListProvider.java | 9 +++---- .../MultiAzClusterTopologyMonitorImpl.java | 10 +++---- .../limitless/LimitlessRouterMonitor.java | 20 ++++++-------- .../limitless/LimitlessRouterServiceImpl.java | 7 ++--- .../HostResponseTimeServiceImpl.java | 6 ++--- .../NodeResponseTimeMonitor.java | 6 +---- .../jdbc/util/ServiceContainerUtility.java | 4 +-- .../util/monitoring/MonitorServiceImpl.java | 26 +++++++++---------- 9 files changed, 38 insertions(+), 58 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java index 90131af7c..5fd965006 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java @@ -85,8 +85,8 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { ClusterTopologyMonitorImpl.class, this.clusterId, this.servicesContainer.getStorageService(), - this.pluginService.getTelemetryFactory(), - this.pluginService.getDefaultConnectionProvider(), + this.servicesContainer.getTelemetryFactory(), + this.servicesContainer.getDefaultConnectionProvider(), this.originalUrl, this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), @@ -130,8 +130,8 @@ protected void clusterIdChanged(final String oldClusterId) throws SQLException { ClusterTopologyMonitorImpl.class, this.clusterId, this.servicesContainer.getStorageService(), - this.pluginService.getTelemetryFactory(), - this.pluginService.getDefaultConnectionProvider(), + this.servicesContainer.getTelemetryFactory(), + this.servicesContainer.getDefaultConnectionProvider(), this.originalUrl, this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java index ab1eb9b23..730d15e7a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java @@ -54,17 +54,16 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { return this.servicesContainer.getMonitorService().runIfAbsent(MultiAzClusterTopologyMonitorImpl.class, this.clusterId, this.servicesContainer.getStorageService(), - this.pluginService.getTelemetryFactory(), - this.pluginService.getDefaultConnectionProvider(), + this.servicesContainer.getTelemetryFactory(), + this.servicesContainer.getDefaultConnectionProvider(), this.originalUrl, this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), this.properties, - (connectionService, pluginService) -> new MultiAzClusterTopologyMonitorImpl( + (servicesContainer) -> new MultiAzClusterTopologyMonitorImpl( + servicesContainer, this.clusterId, - this.servicesContainer.getStorageService(), - connectionService, this.initialHostSpec, this.properties, this.hostListProviderService, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java index 6e3c3b388..36bab8f90 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java @@ -26,9 +26,8 @@ import java.util.logging.Logger; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionService; -import software.amazon.jdbc.util.storage.StorageService; public class MultiAzClusterTopologyMonitorImpl extends ClusterTopologyMonitorImpl { @@ -38,9 +37,8 @@ public class MultiAzClusterTopologyMonitorImpl extends ClusterTopologyMonitorImp protected final String fetchWriterNodeColumnName; public MultiAzClusterTopologyMonitorImpl( + final FullServicesContainer servicesContainer, final String clusterId, - final StorageService storageService, - final ConnectionService connectionService, final HostSpec initialHostSpec, final Properties properties, final HostListProviderService hostListProviderService, @@ -53,12 +51,10 @@ public MultiAzClusterTopologyMonitorImpl( final String fetchWriterNodeQuery, final String fetchWriterNodeColumnName) { super( + servicesContainer, clusterId, - storageService, - connectionService, initialHostSpec, properties, - hostListProviderService, clusterInstanceTemplate, refreshRateNano, highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java index cd983896a..f4075a285 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java @@ -25,11 +25,10 @@ import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.AbstractMonitor; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryContext; @@ -45,27 +44,24 @@ public class LimitlessRouterMonitor extends AbstractMonitor { protected static final long TERMINATION_TIMEOUT_SEC = 5; protected final int intervalMs; protected final @NonNull HostSpec hostSpec; + protected final @NonNull FullServicesContainer servicesContainer; protected final @NonNull StorageService storageService; protected final @NonNull String limitlessRouterCacheKey; protected final @NonNull Properties props; - protected final @NonNull ConnectionService connectionService; protected final @NonNull LimitlessQueryHelper queryHelper; protected final @NonNull TelemetryFactory telemetryFactory; protected Connection monitoringConn = null; public LimitlessRouterMonitor( - final @NonNull PluginService pluginService, - final @NonNull ConnectionService connectionService, - final @NonNull TelemetryFactory telemetryFactory, + final @NonNull FullServicesContainer servicesContainer, final @NonNull HostSpec hostSpec, - final @NonNull StorageService storageService, final @NonNull String limitlessRouterCacheKey, final @NonNull Properties props, final int intervalMs) { super(TERMINATION_TIMEOUT_SEC); - this.connectionService = connectionService; - this.storageService = storageService; - this.telemetryFactory = telemetryFactory; + this.servicesContainer = servicesContainer; + this.storageService = servicesContainer.getStorageService(); + this.telemetryFactory = servicesContainer.getTelemetryFactory(); this.hostSpec = hostSpec; this.limitlessRouterCacheKey = limitlessRouterCacheKey; this.props = PropertyUtils.copyProperties(props); @@ -81,7 +77,7 @@ public LimitlessRouterMonitor( this.props.setProperty(LimitlessConnectionPlugin.WAIT_FOR_ROUTER_INFO.name, "false"); this.intervalMs = intervalMs; - this.queryHelper = new LimitlessQueryHelper(pluginService); + this.queryHelper = new LimitlessQueryHelper(servicesContainer.getPluginService()); } @Override @@ -170,7 +166,7 @@ private void openConnection() throws SQLException { LOGGER.finest(() -> Messages.get( "LimitlessRouterMonitor.openingConnection", new Object[] {this.hostSpec.getUrl()})); - this.monitoringConn = this.connectionService.open(this.hostSpec, this.props); + this.monitoringConn = this.servicesContainer.getPluginService().forceConnect(this.hostSpec, this.props); LOGGER.finest(() -> Messages.get( "LimitlessRouterMonitor.openedConnection", new Object[] {this.monitoringConn})); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java index 824f743fb..9db5cbef2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java @@ -332,12 +332,9 @@ public void startMonitoring(final @NonNull HostSpec hostSpec, this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), props, - (connectionService, pluginService) -> new LimitlessRouterMonitor( - pluginService, - connectionService, - this.servicesContainer.getTelemetryFactory(), + (servicesContainer) -> new LimitlessRouterMonitor( + servicesContainer, hostSpec, - this.servicesContainer.getStorageService(), limitlessRouterMonitorKey, props, intervalMs)); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java index 8d2bad247..a0f156bef 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java @@ -80,14 +80,14 @@ public void setHosts(final @NonNull List hosts) { hostSpec.getUrl(), servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), - this.pluginService.getDefaultConnectionProvider(), + servicesContainer.getDefaultConnectionProvider(), this.pluginService.getOriginalUrl(), this.pluginService.getDriverProtocol(), this.pluginService.getTargetDriverDialect(), this.pluginService.getDialect(), this.props, - (connectionService, pluginService) -> - new NodeResponseTimeMonitor(pluginService, connectionService, hostSpec, this.props, + (servicesContainer) -> + new NodeResponseTimeMonitor(pluginService, hostSpec, this.props, this.intervalMs)); } catch (SQLException e) { LOGGER.warning( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/NodeResponseTimeMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/NodeResponseTimeMonitor.java index 1b985c03f..36322d9c1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/NodeResponseTimeMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/NodeResponseTimeMonitor.java @@ -31,7 +31,6 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.AbstractMonitor; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -56,7 +55,6 @@ public class NodeResponseTimeMonitor extends AbstractMonitor { private final @NonNull Properties props; private final @NonNull PluginService pluginService; - private final @NonNull ConnectionService connectionService; private final TelemetryFactory telemetryFactory; private final TelemetryGauge responseTimeMsGauge; @@ -65,14 +63,12 @@ public class NodeResponseTimeMonitor extends AbstractMonitor { public NodeResponseTimeMonitor( final @NonNull PluginService pluginService, - final @NonNull ConnectionService connectionService, final @NonNull HostSpec hostSpec, final @NonNull Properties props, int intervalMs) { super(TERMINATION_TIMEOUT_SEC); this.pluginService = pluginService; - this.connectionService = connectionService; this.hostSpec = hostSpec; this.props = props; this.intervalMs = intervalMs; @@ -197,7 +193,7 @@ private void openConnection() { LOGGER.finest(() -> Messages.get( "NodeResponseTimeMonitor.openingConnection", new Object[] {this.hostSpec.getUrl()})); - this.monitoringConn = this.connectionService.open(this.hostSpec, monitoringConnProperties); + this.monitoringConn = this.pluginService.forceConnect(this.hostSpec, monitoringConnProperties); LOGGER.finest(() -> Messages.get( "NodeResponseTimeMonitor.openedConnection", new Object[] {this.monitoringConn})); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java index edd922506..c3e79876f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java @@ -21,9 +21,7 @@ import java.util.concurrent.locks.ReentrantLock; import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.DriverConnectionProvider; import software.amazon.jdbc.PartialPluginService; -import software.amazon.jdbc.TargetDriverHelper; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.monitoring.MonitorService; @@ -66,7 +64,7 @@ public static FullServicesContainer createServiceContainer( String targetDriverProtocol, TargetDriverDialect driverDialect, Dialect dbDialect, - Properties props) { + Properties props) throws SQLException { FullServicesContainer servicesContainer = new FullServicesContainerImpl( storageService, monitorService, connectionProvider, telemetryFactory); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index 00ca48580..9353dbf7e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -16,7 +16,6 @@ package software.amazon.jdbc.util.monitoring; -import java.sql.SQLException; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -185,7 +184,7 @@ public T runIfAbsent( TargetDriverDialect driverDialect, Dialect dbDialect, Properties originalProps, - MonitorInitializer initializer) throws SQLException { + MonitorInitializer initializer) { CacheContainer cacheContainer = monitorCaches.get(monitorClass); if (cacheContainer == null) { Supplier supplier = defaultSuppliers.get(monitorClass); @@ -198,17 +197,14 @@ public T runIfAbsent( } final FullServicesContainer servicesContainer = getNewServicesContainer( - storageService, telemetryFactory, originalUrl, driverProtocol, driverDialect, dbDialect, originalProps); - final ConnectionService connectionService = - getConnectionService( - storageService, - telemetryFactory, - defaultConnectionProvider, - originalUrl, - driverProtocol, - driverDialect, - dbDialect, - originalProps); + storageService, + defaultConnectionProvider, + telemetryFactory, + originalUrl, + driverProtocol, + driverDialect, + dbDialect, + originalProps); Monitor monitor = cacheContainer.getCache().computeIfAbsent(key, k -> { MonitorItem monitorItem = new MonitorItem(() -> initializer.createMonitor(servicesContainer)); @@ -226,16 +222,18 @@ public T runIfAbsent( protected FullServicesContainer getNewServicesContainer( StorageService storageService, + ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, String originalUrl, String driverProtocol, TargetDriverDialect driverDialect, Dialect dbDialect, - Properties originalProps) throws SQLException { + Properties originalProps) { final Properties propsCopy = PropertyUtils.copyProperties(originalProps); return ServiceContainerUtility.createServiceContainer( storageService, this, + connectionProvider, telemetryFactory, originalUrl, driverProtocol, From cc9421a19ccb25847daf505a697c48e0267c03ca Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 3 Sep 2025 16:44:57 -0700 Subject: [PATCH 33/54] Build successful --- .../jdbc/benchmarks/PluginBenchmarks.java | 5 +- .../testplugin/TestConnectionWrapper.java | 7 +- .../ClusterTopologyMonitorImpl.java | 40 +- .../ClusterAwareReaderFailoverHandler.java | 2 +- .../ClusterAwareWriterFailoverHandler.java | 4 +- .../amazon/jdbc/util/monitoring/Monitor.java | 2 +- .../util/monitoring/MonitorServiceImpl.java | 5 +- .../dev/DeveloperConnectionPluginTest.java | 3 +- .../LimitlessRouterServiceImplTest.java | 5 +- .../monitoring/MonitorServiceImplTest.java | 628 +++++++++--------- 10 files changed, 364 insertions(+), 337 deletions(-) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java index 22148f312..35932705c 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java @@ -63,7 +63,6 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.GaugeCallable; @@ -94,7 +93,6 @@ public class PluginBenchmarks { @Mock private StorageService mockStorageService; @Mock private MonitorService mockMonitorService; - @Mock private ConnectionService mockConnectionService; @Mock private PluginService mockPluginService; @Mock private TargetDriverDialect mockTargetDriverDialect; @Mock private Dialect mockDialect; @@ -183,8 +181,7 @@ private ConnectionWrapper getConnectionWrapper(Properties props, String connStri mockHostListProviderService, mockPluginManagerService, mockStorageService, - mockMonitorService, - mockConnectionService); + mockMonitorService); } @Benchmark diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/TestConnectionWrapper.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/TestConnectionWrapper.java index d0ec5c063..4323d1ae6 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/TestConnectionWrapper.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/TestConnectionWrapper.java @@ -25,7 +25,6 @@ import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -45,8 +44,7 @@ public TestConnectionWrapper( @NonNull final HostListProviderService hostListProviderService, @NonNull final PluginManagerService pluginManagerService, @NonNull final StorageService storageService, - @NonNull final MonitorService monitorService, - @NonNull final ConnectionService connectionService) + @NonNull final MonitorService monitorService) throws SQLException { super( props, @@ -58,6 +56,7 @@ public TestConnectionWrapper( pluginService, hostListProviderService, pluginManagerService, - storageService, monitorService, connectionService); + storageService, + monitorService); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index 1f3665f6d..7245eb50c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -270,7 +270,7 @@ public void close() { } @Override - public void monitor() { + public void monitor() throws Exception { try { LOGGER.finest(() -> Messages.get( "ClusterTopologyMonitorImpl.startMonitoringThread", @@ -302,15 +302,27 @@ public void monitor() { if (hosts != null && !this.isVerifiedWriterConnection) { for (HostSpec hostSpec : hosts) { + // A list is used to store the exception since lambdas require references to outer variables to be + // final. This allows us to identify if an error occurred while creating the node monitoring worker. + final List exceptionList = new ArrayList<>(); this.submittedNodes.computeIfAbsent(hostSpec.getHost(), (key) -> { final ExecutorService nodeExecutorServiceCopy = this.nodeExecutorService; if (nodeExecutorServiceCopy != null) { - this.nodeExecutorService.submit( - this.getNodeMonitoringWorker(hostSpec, this.writerHostSpec.get())); + try { + this.nodeExecutorService.submit( + this.getNodeMonitoringWorker(hostSpec, this.writerHostSpec.get())); + } catch (SQLException e) { + exceptionList.add(e); + return null; + } } return true; }); + + if (!exceptionList.isEmpty()) { + throw exceptionList.get(0); + } } // It's not possible to call shutdown() on this.nodeExecutorService since more node may be added later. } @@ -351,12 +363,25 @@ public void monitor() { List hosts = this.nodeThreadsLatestTopology.get(); if (hosts != null && !this.nodeThreadsStop.get()) { for (HostSpec hostSpec : hosts) { + // A list is used to store the exception since lambdas require references to outer variables to be + // final. This allows us to identify if an error occurred while creating the node monitoring worker. + final List exceptionList = new ArrayList<>(); this.submittedNodes.computeIfAbsent(hostSpec.getHost(), (key) -> { - this.nodeExecutorService.submit( - this.getNodeMonitoringWorker(hostSpec, this.writerHostSpec.get())); + try { + this.nodeExecutorService.submit( + this.getNodeMonitoringWorker(hostSpec, this.writerHostSpec.get())); + } catch (SQLException e) { + exceptionList.add(e); + return null; + } + return true; }); + + if (!exceptionList.isEmpty()) { + throw exceptionList.get(0); + } } // It's not possible to call shutdown() on this.nodeExecutorService since more node may be added later. } @@ -416,6 +441,7 @@ public void monitor() { ex); } + throw ex; } finally { this.stop.set(true); this.shutdownNodeExecutorService(); @@ -474,11 +500,11 @@ protected boolean isInPanicMode() { } protected Runnable getNodeMonitoringWorker( - final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec) { + final HostSpec hostSpec, final @Nullable HostSpec writerHostSpec) throws SQLException { return new NodeMonitoringWorker(this.getNewServicesContainer(), this, hostSpec, writerHostSpec); } - protected FullServicesContainer getNewServicesContainer() { + protected FullServicesContainer getNewServicesContainer() throws SQLException { return ServiceContainerUtility.createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 71df06662..45fda3943 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -365,7 +365,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( return new ReaderFailoverResult(null, null, false); } - protected FullServicesContainer getNewServicesContainer() { + protected FullServicesContainer getNewServicesContainer() throws SQLException { return ServiceContainerUtility.createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 48055df42..2c498151b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -141,7 +141,7 @@ private void submitTasks( final List currentTopology, final ExecutorService executorService, final CompletionService completionService, - final boolean singleTask) { + final boolean singleTask) throws SQLException { final HostSpec writerHost = Utils.getWriter(currentTopology); if (!singleTask) { completionService.submit( @@ -166,7 +166,7 @@ private void submitTasks( executorService.shutdown(); } - protected FullServicesContainer getNewServicesContainer() { + protected FullServicesContainer getNewServicesContainer() throws SQLException { // Each task should get its own FullServicesContainer since they execute concurrently and PluginService was not // designed to be thread-safe. return ServiceContainerUtility.createServiceContainer( diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/Monitor.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/Monitor.java index d4d89dc4c..fbdd55063 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/Monitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/Monitor.java @@ -27,7 +27,7 @@ public interface Monitor { * submitted during the call to {@link #start()}. Additionally, the monitoring loop should regularly update the last * activity timestamp so that the {@link MonitorService} can detect whether the monitor is stuck or not. */ - void monitor(); + void monitor() throws Exception; /** * Stops the monitoring tasks for this monitor and closes all resources. diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index 9353dbf7e..e8567cdc0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -16,6 +16,7 @@ package software.amazon.jdbc.util.monitoring; +import java.sql.SQLException; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -184,7 +185,7 @@ public T runIfAbsent( TargetDriverDialect driverDialect, Dialect dbDialect, Properties originalProps, - MonitorInitializer initializer) { + MonitorInitializer initializer) throws SQLException { CacheContainer cacheContainer = monitorCaches.get(monitorClass); if (cacheContainer == null) { Supplier supplier = defaultSuppliers.get(monitorClass); @@ -228,7 +229,7 @@ protected FullServicesContainer getNewServicesContainer( String driverProtocol, TargetDriverDialect driverDialect, Dialect dbDialect, - Properties originalProps) { + Properties originalProps) throws SQLException { final Properties propsCopy = PropertyUtils.copyProperties(originalProps); return ServiceContainerUtility.createServiceContainer( storageService, diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java index 638e99a4f..f4cc60fec 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java @@ -73,7 +73,8 @@ void cleanUp() throws Exception { @BeforeEach void init() throws SQLException { closeable = MockitoAnnotations.openMocks(this); - servicesContainer = new FullServicesContainerImpl(mockStorageService, mockMonitorService, mockTelemetryFactory); + servicesContainer = new FullServicesContainerImpl( + mockStorageService, mockMonitorService, mockConnectionProvider, mockTelemetryFactory); when(mockConnectionProvider.connect(any(), any(), any(), any(), any())).thenReturn(mockConnection); when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())).thenReturn(null); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java index 4412e2368..32e0ebb46 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HighestWeightHostSelector; import software.amazon.jdbc.HostListProvider; import software.amazon.jdbc.HostRole; @@ -61,6 +62,7 @@ class LimitlessRouterServiceImplTest { private static final String CLUSTER_ID = "someClusterId"; @Mock private EventPublisher mockEventPublisher; @Mock private MonitorService mockMonitorService; + @Mock private ConnectionProvider mockConnectionProvider; @Mock private TelemetryFactory mockTelemetryFactory; @Mock private PluginService mockPluginService; @Mock private HostListProvider mockHostListProvider; @@ -84,7 +86,8 @@ public void init() throws SQLException { when(mockHostListProvider.getClusterId()).thenReturn(CLUSTER_ID); this.storageService = new StorageServiceImpl(mockEventPublisher); - servicesContainer = new FullServicesContainerImpl(this.storageService, mockMonitorService, mockTelemetryFactory); + servicesContainer = new FullServicesContainerImpl( + this.storageService, mockMonitorService, mockConnectionProvider, mockTelemetryFactory); servicesContainer.setPluginService(mockPluginService); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 38400e9a6..9251b27e5 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -1,314 +1,314 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.util.monitoring; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; - -import java.sql.SQLException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionService; -import software.amazon.jdbc.util.events.EventPublisher; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; - -class MonitorServiceImplTest { - @Mock StorageService mockStorageService; - @Mock ConnectionService mockConnectionService; - @Mock ConnectionProvider mockConnectionProvider; - @Mock TelemetryFactory mockTelemetryFactory; - @Mock TargetDriverDialect mockTargetDriverDialect; - @Mock Dialect mockDbDialect; - @Mock EventPublisher mockPublisher; - MonitorServiceImpl spyMonitorService; - private AutoCloseable closeable; - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); - doNothing().when(spyMonitorService).initCleanupThread(anyInt()); - - try { - doReturn(mockConnectionService).when(spyMonitorService) - .getConnectionService(any(), any(), any(), any(), any(), any(), any(), any()); - } catch (SQLException e) { - Assertions.fail( - "Encountered exception while stubbing MonitorServiceImpl#getConnectionService: " + e.getMessage()); - } - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - spyMonitorService.releaseResources(); - } - - @Test - public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - monitor.state.set(MonitorState.ERROR); - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(newMonitor); - assertNotEquals(monitor, newMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, newMonitor.getState()); - } - - @Test - public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - 1, // heartbeat times out immediately - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(newMonitor); - assertNotEquals(monitor, newMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, newMonitor.getState()); - } - - @Test - public void testMonitorExpired() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - // checkMonitors() should detect the expiration timeout and stop/remove the monitor. - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - // monitor should have been removed when checkMonitors() was called. - assertNull(newMonitor); - } - - @Test - public void testMonitorMismatch() { - assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( - CustomEndpointMonitorImpl.class, - "testMonitor", - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor - // service should detect this and throw an exception. - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - )); - } - - @Test - public void testRemove() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - assertNotNull(monitor); - - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); - assertEquals(monitor, removedMonitor); - assertEquals(MonitorState.RUNNING, monitor.getState()); - } - - @Test - public void testStopAndRemove() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - "jdbc:postgresql://somehost/somedb", - "someProtocol", - mockTargetDriverDialect, - mockDbDialect, - new Properties(), - (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) - ); - assertNotNull(monitor); - - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - spyMonitorService.stopAndRemove(NoOpMonitor.class, key); - assertNull(spyMonitorService.get(NoOpMonitor.class, key)); - assertEquals(MonitorState.STOPPED, monitor.getState()); - } - - static class NoOpMonitor extends AbstractMonitor { - protected NoOpMonitor( - MonitorService monitorService, - long terminationTimeoutSec) { - super(terminationTimeoutSec); - } - - @Override - public void monitor() { - // do nothing. - } - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.util.monitoring; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.Mockito.doNothing; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// +// import java.sql.SQLException; +// import java.util.Collections; +// import java.util.HashSet; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.Assertions; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.ConnectionProvider; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.connection.ConnectionService; +// import software.amazon.jdbc.util.events.EventPublisher; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// +// class MonitorServiceImplTest { +// @Mock StorageService mockStorageService; +// @Mock ConnectionService mockConnectionService; +// @Mock ConnectionProvider mockConnectionProvider; +// @Mock TelemetryFactory mockTelemetryFactory; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// @Mock Dialect mockDbDialect; +// @Mock EventPublisher mockPublisher; +// MonitorServiceImpl spyMonitorService; +// private AutoCloseable closeable; +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); +// doNothing().when(spyMonitorService).initCleanupThread(anyInt()); +// +// try { +// doReturn(mockConnectionService).when(spyMonitorService) +// .getConnectionService(any(), any(), any(), any(), any(), any(), any(), any()); +// } catch (SQLException e) { +// Assertions.fail( +// "Encountered exception while stubbing MonitorServiceImpl#getConnectionService: " + e.getMessage()); +// } +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// spyMonitorService.releaseResources(); +// } +// +// @Test +// public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// monitor.state.set(MonitorState.ERROR); +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(newMonitor); +// assertNotEquals(monitor, newMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, newMonitor.getState()); +// } +// +// @Test +// public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// 1, // heartbeat times out immediately +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(newMonitor); +// assertNotEquals(monitor, newMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, newMonitor.getState()); +// } +// +// @Test +// public void testMonitorExpired() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// // checkMonitors() should detect the expiration timeout and stop/remove the monitor. +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// // monitor should have been removed when checkMonitors() was called. +// assertNull(newMonitor); +// } +// +// @Test +// public void testMonitorMismatch() { +// assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( +// CustomEndpointMonitorImpl.class, +// "testMonitor", +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor +// // service should detect this and throw an exception. +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// )); +// } +// +// @Test +// public void testRemove() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// assertNotNull(monitor); +// +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); +// assertEquals(monitor, removedMonitor); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// } +// +// @Test +// public void testStopAndRemove() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// "jdbc:postgresql://somehost/somedb", +// "someProtocol", +// mockTargetDriverDialect, +// mockDbDialect, +// new Properties(), +// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) +// ); +// assertNotNull(monitor); +// +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// spyMonitorService.stopAndRemove(NoOpMonitor.class, key); +// assertNull(spyMonitorService.get(NoOpMonitor.class, key)); +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// } +// +// static class NoOpMonitor extends AbstractMonitor { +// protected NoOpMonitor( +// MonitorService monitorService, +// long terminationTimeoutSec) { +// super(terminationTimeoutSec); +// } +// +// @Override +// public void monitor() { +// // do nothing. +// } +// } +// } From 3f7fb28fdf5736b742f1b3d9c4e658fa58e26c63 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 3 Sep 2025 17:11:10 -0700 Subject: [PATCH 34/54] Failover test passing --- .../java/software/amazon/jdbc/PartialPluginService.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 34391c3a5..c6b494651 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -510,8 +510,7 @@ public Connection forceConnect( final HostSpec hostSpec, final Properties props) throws SQLException { - throw new UnsupportedOperationException( - Messages.get("PartialPluginService.unexpectedMethodCall", new Object[] {"forceConnect"})); + return this.forceConnect(hostSpec, props, null); } @Override @@ -520,8 +519,8 @@ public Connection forceConnect( final Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { - throw new UnsupportedOperationException( - Messages.get("PartialPluginService.unexpectedMethodCall", new Object[] {"forceConnect"})); + return this.pluginManager.forceConnect( + this.driverProtocol, hostSpec, props, true, pluginToSkip); } private void updateHostAvailability(final List hosts) { From edeea73f4fff7b877a0ebed932528a7abd400f79 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 5 Sep 2025 16:19:21 -0700 Subject: [PATCH 35/54] PR suggestions --- .../ClusterAwareReaderFailoverHandler.java | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index d0512cf55..7cb4f7b5a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -19,7 +19,6 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -308,15 +307,11 @@ private ReaderFailoverResult getConnectionFromHostGroup(final List hos final ExecutorService executor = ExecutorFactory.newFixedThreadPool(2, "failover"); final CompletionService completionService = new ExecutorCompletionService<>(executor); - // The ConnectionAttemptTask threads should have their own plugin services since they execute concurrently and - // PluginService was not designed to be thread-safe. - List pluginServices = Arrays.asList(getNewPluginService(), getNewPluginService()); - try { for (int i = 0; i < hosts.size(); i += 2) { // submit connection attempt tasks in batches of 2 final ReaderFailoverResult result = - getResultFromNextTaskBatch(hosts, executor, completionService, pluginServices, i); + getResultFromNextTaskBatch(hosts, executor, completionService, i); if (result.isConnected() || result.getException() != null) { return result; } @@ -339,14 +334,13 @@ private ReaderFailoverResult getResultFromNextTaskBatch( final List hosts, final ExecutorService executor, final CompletionService completionService, - final List pluginServices, final int i) throws SQLException { ReaderFailoverResult result; final int numTasks = i + 1 < hosts.size() ? 2 : 1; completionService.submit( new ConnectionAttemptTask( this.connectionService, - pluginServices.get(0), + getNewPluginService(), this.hostAvailabilityMap, hosts.get(i), this.props, @@ -355,7 +349,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( completionService.submit( new ConnectionAttemptTask( this.connectionService, - pluginServices.get(1), + getNewPluginService(), this.hostAvailabilityMap, hosts.get(i + 1), this.props, @@ -398,6 +392,8 @@ private ReaderFailoverResult getNextResult(final CompletionService Date: Mon, 8 Sep 2025 09:57:26 -0700 Subject: [PATCH 36/54] Rename to ServiceUtility --- .../monitoring/ClusterTopologyMonitorImpl.java | 4 ++-- .../failover/ClusterAwareReaderFailoverHandler.java | 4 ++-- .../failover/ClusterAwareWriterFailoverHandler.java | 4 ++-- ...viceContainerUtility.java => ServiceUtility.java} | 12 ++++++------ .../jdbc/util/monitoring/MonitorServiceImpl.java | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) rename wrapper/src/main/java/software/amazon/jdbc/util/{ServiceContainerUtility.java => ServiceUtility.java} (89%) diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index 7245eb50c..c37a95bba 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -52,7 +52,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.ServiceContainerUtility; +import software.amazon.jdbc.util.ServiceUtility; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SynchronousExecutor; import software.amazon.jdbc.util.Utils; @@ -505,7 +505,7 @@ protected Runnable getNodeMonitoringWorker( } protected FullServicesContainer getNewServicesContainer() throws SQLException { - return ServiceContainerUtility.createServiceContainer( + return ServiceUtility.getInstance().createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), this.servicesContainer.getDefaultConnectionProvider(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index fae727f0f..e58378ddb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.ServiceContainerUtility; +import software.amazon.jdbc.util.ServiceUtility; import software.amazon.jdbc.util.Utils; /** @@ -359,7 +359,7 @@ private ReaderFailoverResult getResultFromNextTaskBatch( } protected FullServicesContainer getNewServicesContainer() throws SQLException { - return ServiceContainerUtility.createServiceContainer( + return ServiceUtility.getInstance().createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 2c498151b..9afe693cd 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -39,7 +39,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.ServiceContainerUtility; +import software.amazon.jdbc.util.ServiceUtility; import software.amazon.jdbc.util.Utils; /** @@ -169,7 +169,7 @@ private void submitTasks( protected FullServicesContainer getNewServicesContainer() throws SQLException { // Each task should get its own FullServicesContainer since they execute concurrently and PluginService was not // designed to be thread-safe. - return ServiceContainerUtility.createServiceContainer( + return ServiceUtility.getInstance().createServiceContainer( this.servicesContainer.getStorageService(), this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java similarity index 89% rename from wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java rename to wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index f07f1f01e..84d6be614 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceContainerUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -28,17 +28,17 @@ import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; -public class ServiceContainerUtility { - private static volatile ServiceContainerUtility instance; +public class ServiceUtility { + private static volatile ServiceUtility instance; private static final ReentrantLock initLock = new ReentrantLock(); - private ServiceContainerUtility() { + private ServiceUtility() { if (instance != null) { throw new IllegalStateException("ServiceContainerUtility singleton instance already exists."); } } - public static ServiceContainerUtility getInstance() { + public static ServiceUtility getInstance() { if (instance != null) { return instance; } @@ -46,7 +46,7 @@ public static ServiceContainerUtility getInstance() { initLock.lock(); try { if (instance == null) { - instance = new ServiceContainerUtility(); + instance = new ServiceUtility(); } } finally { initLock.unlock(); @@ -55,7 +55,7 @@ public static ServiceContainerUtility getInstance() { return instance; } - public static FullServicesContainer createServiceContainer( + public FullServicesContainer createServiceContainer( StorageService storageService, MonitorService monitorService, ConnectionProvider connectionProvider, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index e8567cdc0..5f8b21e68 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.ServiceContainerUtility; +import software.amazon.jdbc.util.ServiceUtility; import software.amazon.jdbc.util.events.DataAccessEvent; import software.amazon.jdbc.util.events.Event; import software.amazon.jdbc.util.events.EventPublisher; @@ -231,7 +231,7 @@ protected FullServicesContainer getNewServicesContainer( Dialect dbDialect, Properties originalProps) throws SQLException { final Properties propsCopy = PropertyUtils.copyProperties(originalProps); - return ServiceContainerUtility.createServiceContainer( + return ServiceUtility.getInstance().createServiceContainer( storageService, this, connectionProvider, From 77dccce6e9f13b6a9287991116351a96860c27fe Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 15 Sep 2025 14:17:10 -0700 Subject: [PATCH 37/54] Fix unit tests --- ...ClusterAwareReaderFailoverHandlerTest.java | 797 ++++++++-------- ...ClusterAwareWriterFailoverHandlerTest.java | 745 ++++++++------- .../FailoverConnectionPluginTest.java | 892 +++++++++--------- .../monitoring/MonitorServiceImplTest.java | 629 ++++++------ 4 files changed, 1529 insertions(+), 1534 deletions(-) diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java index 7f7d3567c..e4afa0936 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandlerTest.java @@ -1,400 +1,397 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.failover; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.junit.jupiter.api.Assertions.assertSame; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.when; -// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; -// import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.List; -// import java.util.Map; -// import java.util.Properties; -// import java.util.Set; -// import java.util.concurrent.TimeUnit; -// import java.util.stream.Collectors; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import org.mockito.stubbing.Answer; -// import software.amazon.jdbc.ConnectionPluginManager; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.connection.ConnectionService; -// -// class ClusterAwareReaderFailoverHandlerTest { -// @Mock FullServicesContainer mockContainer; -// @Mock ConnectionService mockConnectionService; -// @Mock PluginService mockPluginService; -// @Mock ConnectionPluginManager mockPluginManager; -// @Mock Connection mockConnection; -// -// private AutoCloseable closeable; -// private final Properties properties = new Properties(); -// private final List defaultHosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer").port(1234).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader1").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader2").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader3").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader4").port(1234).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader5").port(1234).role(HostRole.READER).build() -// ); -// -// @BeforeEach -// void setUp() { -// closeable = MockitoAnnotations.openMocks(this); -// when(mockContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); -// when(mockContainer.getPluginService()).thenReturn(mockPluginService); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// } -// -// @Test -// public void testFailover() throws SQLException { -// // original host list: [active writer, active reader, current connection (reader), active -// // reader, down reader, active reader] -// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] -// // connection attempts are made in pairs using the above list -// // expected test result: successful connection for host at index 4 -// final List hosts = defaultHosts; -// final int currentHostIndex = 2; -// final int successHostIndex = 4; -// for (int i = 0; i < hosts.size(); i++) { -// if (i != successHostIndex) { -// final SQLException exception = new SQLException("exception", "08S01", null); -// when(mockConnectionService.open(hosts.get(i), properties)) -// .thenThrow(exception); -// when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); -// } else { -// when(mockConnectionService.open(hosts.get(i), properties)).thenReturn(mockConnection); -// } -// } -// -// when(mockPluginService.getTargetDriverDialect()).thenReturn(null); -// -// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// final ReaderFailoverHandler target = getSpyFailoverHandler(); -// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); -// -// assertTrue(result.isConnected()); -// assertSame(mockConnection, result.getConnection()); -// assertEquals(hosts.get(successHostIndex), result.getHost()); -// -// final HostSpec successHost = hosts.get(successHostIndex); -// final Map availabilityMap = target.getHostAvailabilityMap(); -// Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); -// assertTrue(unavailableHosts.size() >= 4); -// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); -// } -// -// private Set getHostsWithGivenAvailability( -// Map availabilityMap, HostAvailability availability) { -// return availabilityMap.entrySet().stream() -// .filter((entry) -> availability.equals(entry.getValue())) -// .map(Map.Entry::getKey) -// .collect(Collectors.toSet()); -// } -// -// @Test -// public void testFailover_timeout() throws SQLException { -// // original host list: [active writer, active reader, current connection (reader), active -// // reader, down reader, active reader] -// // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] -// // connection attempts are made in pairs using the above list -// // expected test result: failure to get reader since process is limited to 5s and each attempt -// // to connect takes 20s -// final List hosts = defaultHosts; -// final int currentHostIndex = 2; -// for (HostSpec host : hosts) { -// when(mockConnectionService.open(host, properties)) -// .thenAnswer((Answer) invocation -> { -// Thread.sleep(20000); -// return mockConnection; -// }); -// } -// -// hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); -// -// final long startTimeNano = System.nanoTime(); -// final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); -// final long durationNano = System.nanoTime() - startTimeNano; -// -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// -// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements -// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); -// } -// -// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() throws SQLException { -// ClusterAwareReaderFailoverHandler handler = -// spy(new ClusterAwareReaderFailoverHandler(mockContainer, mockConnectionService, properties)); -// doReturn(mockPluginService).when(handler).getNewPluginService(); -// return handler; -// } -// -// private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( -// int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) throws SQLException { -// ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( -// mockContainer, mockConnectionService, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); -// ClusterAwareReaderFailoverHandler spyHandler = spy(handler); -// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); -// return spyHandler; -// } -// -// @Test -// public void testFailover_nullOrEmptyHostList() throws SQLException { -// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); -// final HostSpec currentHost = -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); -// -// ReaderFailoverResult result = target.failover(null, currentHost); -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// -// final List hosts = new ArrayList<>(); -// result = target.failover(hosts, currentHost); -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// } -// -// @Test -// public void testGetReader_connectionSuccess() throws SQLException { -// // even number of connection attempts -// // first connection attempt to return succeeds, second attempt cancelled -// // expected test result: successful connection for host at index 2 -// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) -// final HostSpec slowHost = hosts.get(1); -// final HostSpec fastHost = hosts.get(2); -// when(mockConnectionService.open(slowHost, properties)) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(20000); -// return mockConnection; -// }); -// when(mockConnectionService.open(eq(fastHost), eq(properties))).thenReturn(mockConnection); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ReaderFailoverHandler target = getSpyFailoverHandler(); -// final ReaderFailoverResult result = target.getReaderConnection(hosts); -// -// assertTrue(result.isConnected()); -// assertSame(mockConnection, result.getConnection()); -// assertEquals(hosts.get(2), result.getHost()); -// -// Map availabilityMap = target.getHostAvailabilityMap(); -// assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); -// assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); -// } -// -// @Test -// public void testGetReader_connectionFailure() throws SQLException { -// // odd number of connection attempts -// // first connection attempt to return fails -// // expected test result: failure to get reader -// final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) -// when(mockConnectionService.open(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ReaderFailoverHandler target = getSpyFailoverHandler(); -// final ReaderFailoverResult result = target.getReaderConnection(hosts); -// -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// } -// -// @Test -// public void testGetReader_connectionAttemptsTimeout() throws SQLException { -// // connection attempts time out before they can succeed -// // first connection attempt to return times out -// // expected test result: failure to get reader -// final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) -// when(mockConnectionService.open(any(), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// try { -// Thread.sleep(5000); -// } catch (InterruptedException exception) { -// // ignore -// } -// return mockConnection; -// }); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); -// final ReaderFailoverResult result = target.getReaderConnection(hosts); -// -// assertFalse(result.isConnected()); -// assertNull(result.getConnection()); -// assertNull(result.getHost()); -// } -// -// @Test -// public void testGetHostTuplesByPriority() throws SQLException { -// final List originalHosts = defaultHosts; -// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); -// final List hostsByPriority = target.getHostsByPriority(originalHosts); -// -// int i = 0; -// -// // expecting active readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { -// i++; -// } -// -// // expecting a writer -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.WRITER) { -// i++; -// } -// -// // expecting down readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { -// i++; -// } -// -// assertEquals(hostsByPriority.size(), i); -// } -// -// @Test -// public void testGetReaderTuplesByPriority() throws SQLException { -// final List originalHosts = defaultHosts; -// originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); -// originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); -// final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); -// -// int i = 0; -// -// // expecting active readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { -// i++; -// } -// -// // expecting down readers -// while (i < hostsByPriority.size() -// && hostsByPriority.get(i).getRole() == HostRole.READER -// && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { -// i++; -// } -// -// assertEquals(hostsByPriority.size(), i); -// } -// -// @Test -// public void testHostFailoverStrictReaderEnabled() throws SQLException { -// final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer").port(1234).role(HostRole.WRITER).build(); -// final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader1").port(1234).role(HostRole.READER).build(); -// final List hosts = Arrays.asList(writer, reader); -// -// Dialect mockDialect = Mockito.mock(Dialect.class); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// -// final ClusterAwareReaderFailoverHandler target = -// getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); -// -// // The writer is included because the original writer has likely become a reader. -// List expectedHostsByPriority = Arrays.asList(reader, writer); -// -// List hostsByPriority = target.getHostsByPriority(hosts); -// assertEquals(expectedHostsByPriority, hostsByPriority); -// -// // Should pick the reader even if unavailable. The unavailable reader will be lower priority than the writer. -// reader.setAvailability(HostAvailability.NOT_AVAILABLE); -// expectedHostsByPriority = Arrays.asList(writer, reader); -// -// hostsByPriority = target.getHostsByPriority(hosts); -// assertEquals(expectedHostsByPriority, hostsByPriority); -// -// // Writer node will only be picked if it is the only node in topology; -// List expectedWriterHost = Collections.singletonList(writer); -// -// hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); -// assertEquals(expectedWriterHost, hostsByPriority); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.failover; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_FAILOVER_TIMEOUT; +import static software.amazon.jdbc.plugin.failover.ClusterAwareReaderFailoverHandler.DEFAULT_READER_CONNECT_TIMEOUT; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.stubbing.Answer; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.FullServicesContainer; + +class ClusterAwareReaderFailoverHandlerTest { + @Mock FullServicesContainer mockContainer1; + @Mock FullServicesContainer mockContainer2; + @Mock PluginService mockPluginService; + @Mock Connection mockConnection; + + private AutoCloseable closeable; + private final Properties properties = new Properties(); + private final List defaultHosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer").port(1234).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader1").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader2").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader3").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader4").port(1234).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader5").port(1234).role(HostRole.READER).build() + ); + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + when(mockContainer1.getPluginService()).thenReturn(mockPluginService); + when(mockContainer2.getPluginService()).thenReturn(mockPluginService); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + public void testFailover() throws SQLException { + // original host list: [active writer, active reader, current connection (reader), active + // reader, down reader, active reader] + // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] + // connection attempts are made in pairs using the above list + // expected test result: successful connection for host at index 4 + final List hosts = defaultHosts; + final int currentHostIndex = 2; + final int successHostIndex = 4; + for (int i = 0; i < hosts.size(); i++) { + if (i != successHostIndex) { + final SQLException exception = new SQLException("exception", "08S01", null); + when(mockPluginService.forceConnect(hosts.get(i), properties)) + .thenThrow(exception); + when(mockPluginService.isNetworkException(exception, null)).thenReturn(true); + } else { + when(mockPluginService.forceConnect(hosts.get(i), properties)).thenReturn(mockConnection); + } + } + + when(mockPluginService.getTargetDriverDialect()).thenReturn(null); + + hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + + final ReaderFailoverHandler target = getSpyFailoverHandler(); + final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); + + assertTrue(result.isConnected()); + assertSame(mockConnection, result.getConnection()); + assertEquals(hosts.get(successHostIndex), result.getHost()); + + final HostSpec successHost = hosts.get(successHostIndex); + final Map availabilityMap = target.getHostAvailabilityMap(); + Set unavailableHosts = getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE); + assertTrue(unavailableHosts.size() >= 4); + assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(successHost.getHost())); + } + + private Set getHostsWithGivenAvailability( + Map availabilityMap, HostAvailability availability) { + return availabilityMap.entrySet().stream() + .filter((entry) -> availability.equals(entry.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + @Test + public void testFailover_timeout() throws SQLException { + // original host list: [active writer, active reader, current connection (reader), active + // reader, down reader, active reader] + // priority order by index (the subsets will be shuffled): [[1, 3, 5], 0, [2, 4]] + // connection attempts are made in pairs using the above list + // expected test result: failure to get reader since process is limited to 5s and each attempt + // to connect takes 20s + final List hosts = defaultHosts; + final int currentHostIndex = 2; + for (HostSpec host : hosts) { + when(mockPluginService.forceConnect(host, properties)) + .thenAnswer((Answer) invocation -> { + Thread.sleep(20000); + return mockConnection; + }); + } + + hosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + hosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + + final ReaderFailoverHandler target = getSpyFailoverHandler(5000, 30000, false); + + final long startTimeNano = System.nanoTime(); + final ReaderFailoverResult result = target.failover(hosts, hosts.get(currentHostIndex)); + final long durationNano = System.nanoTime() - startTimeNano; + + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + + // 5s is a max allowed failover timeout; add 1s for inaccurate measurements + assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); + } + + private ClusterAwareReaderFailoverHandler getSpyFailoverHandler() throws SQLException { + ClusterAwareReaderFailoverHandler handler = + spy(new ClusterAwareReaderFailoverHandler(mockContainer1, properties)); + doReturn(mockContainer2).when(handler).getNewServicesContainer(); + return handler; + } + + private ClusterAwareReaderFailoverHandler getSpyFailoverHandler( + int maxFailoverTimeoutMs, int timeoutMs, boolean isStrictReaderRequired) throws SQLException { + ClusterAwareReaderFailoverHandler handler = new ClusterAwareReaderFailoverHandler( + mockContainer1, properties, maxFailoverTimeoutMs, timeoutMs, isStrictReaderRequired); + ClusterAwareReaderFailoverHandler spyHandler = spy(handler); + doReturn(mockContainer2).when(spyHandler).getNewServicesContainer(); + return spyHandler; + } + + @Test + public void testFailover_nullOrEmptyHostList() throws SQLException { + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); + final HostSpec currentHost = + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("writer").port(1234).build(); + + ReaderFailoverResult result = target.failover(null, currentHost); + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + + final List hosts = new ArrayList<>(); + result = target.failover(hosts, currentHost); + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + } + + @Test + public void testGetReader_connectionSuccess() throws SQLException { + // even number of connection attempts + // first connection attempt to return succeeds, second attempt cancelled + // expected test result: successful connection for host at index 2 + final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) + final HostSpec slowHost = hosts.get(1); + final HostSpec fastHost = hosts.get(2); + when(mockPluginService.forceConnect(slowHost, properties)) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(20000); + return mockConnection; + }); + when(mockPluginService.forceConnect(eq(fastHost), eq(properties))).thenReturn(mockConnection); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ReaderFailoverHandler target = getSpyFailoverHandler(); + final ReaderFailoverResult result = target.getReaderConnection(hosts); + + assertTrue(result.isConnected()); + assertSame(mockConnection, result.getConnection()); + assertEquals(hosts.get(2), result.getHost()); + + Map availabilityMap = target.getHostAvailabilityMap(); + assertTrue(getHostsWithGivenAvailability(availabilityMap, HostAvailability.NOT_AVAILABLE).isEmpty()); + assertEquals(HostAvailability.AVAILABLE, availabilityMap.get(fastHost.getHost())); + } + + @Test + public void testGetReader_connectionFailure() throws SQLException { + // odd number of connection attempts + // first connection attempt to return fails + // expected test result: failure to get reader + final List hosts = defaultHosts.subList(0, 4); // 3 connection attempts (writer not attempted) + when(mockPluginService.forceConnect(any(), eq(properties))).thenThrow(new SQLException("exception", "08S01", null)); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ReaderFailoverHandler target = getSpyFailoverHandler(); + final ReaderFailoverResult result = target.getReaderConnection(hosts); + + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + } + + @Test + public void testGetReader_connectionAttemptsTimeout() throws SQLException { + // connection attempts time out before they can succeed + // first connection attempt to return times out + // expected test result: failure to get reader + final List hosts = defaultHosts.subList(0, 3); // 2 connection attempts (writer not attempted) + when(mockPluginService.forceConnect(any(), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + try { + Thread.sleep(5000); + } catch (InterruptedException exception) { + // ignore + } + return mockConnection; + }); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(60000, 1000, false); + final ReaderFailoverResult result = target.getReaderConnection(hosts); + + assertFalse(result.isConnected()); + assertNull(result.getConnection()); + assertNull(result.getHost()); + } + + @Test + public void testGetHostTuplesByPriority() throws SQLException { + final List originalHosts = defaultHosts; + originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); + + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); + final List hostsByPriority = target.getHostsByPriority(originalHosts); + + int i = 0; + + // expecting active readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { + i++; + } + + // expecting a writer + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.WRITER) { + i++; + } + + // expecting down readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { + i++; + } + + assertEquals(hostsByPriority.size(), i); + } + + @Test + public void testGetReaderTuplesByPriority() throws SQLException { + final List originalHosts = defaultHosts; + originalHosts.get(2).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(4).setAvailability(HostAvailability.NOT_AVAILABLE); + originalHosts.get(5).setAvailability(HostAvailability.NOT_AVAILABLE); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ClusterAwareReaderFailoverHandler target = getSpyFailoverHandler(); + final List hostsByPriority = target.getReaderHostsByPriority(originalHosts); + + int i = 0; + + // expecting active readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.AVAILABLE) { + i++; + } + + // expecting down readers + while (i < hostsByPriority.size() + && hostsByPriority.get(i).getRole() == HostRole.READER + && hostsByPriority.get(i).getAvailability() == HostAvailability.NOT_AVAILABLE) { + i++; + } + + assertEquals(hostsByPriority.size(), i); + } + + @Test + public void testHostFailoverStrictReaderEnabled() throws SQLException { + final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer").port(1234).role(HostRole.WRITER).build(); + final HostSpec reader = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader1").port(1234).role(HostRole.READER).build(); + final List hosts = Arrays.asList(writer, reader); + + Dialect mockDialect = Mockito.mock(Dialect.class); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + + final ClusterAwareReaderFailoverHandler target = + getSpyFailoverHandler(DEFAULT_FAILOVER_TIMEOUT, DEFAULT_READER_CONNECT_TIMEOUT, true); + + // The writer is included because the original writer has likely become a reader. + List expectedHostsByPriority = Arrays.asList(reader, writer); + + List hostsByPriority = target.getHostsByPriority(hosts); + assertEquals(expectedHostsByPriority, hostsByPriority); + + // Should pick the reader even if unavailable. The unavailable reader will have lower priority than the writer. + reader.setAvailability(HostAvailability.NOT_AVAILABLE); + expectedHostsByPriority = Arrays.asList(writer, reader); + + hostsByPriority = target.getHostsByPriority(hosts); + assertEquals(expectedHostsByPriority, hostsByPriority); + + // Writer node will only be picked if it is the only node in topology; + List expectedWriterHost = Collections.singletonList(writer); + + hostsByPriority = target.getHostsByPriority(Collections.singletonList(writer)); + assertEquals(expectedWriterHost, hostsByPriority); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 98d392cdb..b17ecbb95 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -1,373 +1,372 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.failover; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertSame; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.ArgumentMatchers.refEq; -// import static org.mockito.Mockito.atLeastOnce; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.Arrays; -// import java.util.EnumSet; -// import java.util.List; -// import java.util.Properties; -// import java.util.concurrent.TimeUnit; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.ArgumentMatchers; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import org.mockito.stubbing.Answer; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.connection.ConnectionService; -// -// class ClusterAwareWriterFailoverHandlerTest { -// @Mock FullServicesContainer mockContainer; -// @Mock ConnectionService mockConnectionService; -// @Mock PluginService mockPluginService; -// @Mock Connection mockConnection; -// @Mock ReaderFailoverHandler mockReaderFailoverHandler; -// @Mock Connection mockWriterConnection; -// @Mock Connection mockNewWriterConnection; -// @Mock Connection mockReaderAConnection; -// @Mock Connection mockReaderBConnection; -// @Mock Dialect mockDialect; -// -// private AutoCloseable closeable; -// private final Properties properties = new Properties(); -// private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("new-writer-host").build(); -// private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer-host").build(); -// private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader-a-host").build(); -// private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader-b-host").build(); -// private final List topology = Arrays.asList(writer, readerA, readerB); -// private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); -// -// @BeforeEach -// void setUp() { -// closeable = MockitoAnnotations.openMocks(this); -// when(mockContainer.getPluginService()).thenReturn(mockPluginService); -// writer.addAlias("writer-host"); -// newWriterHost.addAlias("new-writer-host"); -// readerA.addAlias("reader-a-host"); -// readerB.addAlias("reader-b-host"); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// } -// -// @Test -// public void testReconnectToWriter_taskBReaderException() throws SQLException { -// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockConnection); -// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenThrow(SQLException.class); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); -// -// when(mockPluginService.getAllHosts()).thenReturn(topology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertFalse(result.isNewHost()); -// assertSame(result.getNewConnection(), mockConnection); -// -// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); -// } -// -// private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( -// final int failoverTimeoutMs, -// final int readTopologyIntervalMs, -// final int reconnectWriterIntervalMs) throws SQLException { -// ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( -// mockContainer, -// mockConnectionService, -// mockReaderFailoverHandler, -// properties, -// failoverTimeoutMs, -// readTopologyIntervalMs, -// reconnectWriterIntervalMs); -// -// ClusterAwareWriterFailoverHandler spyHandler = spy(handler); -// doReturn(mockPluginService).when(spyHandler).getNewPluginService(); -// return spyHandler; -// } -// -// /** -// * Verify that writer failover handler can re-connect to a current writer node. -// * -// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. -// * TaskA: successfully re-connect to initial writer; return new connection. -// * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. -// * Expected test result: new connection by taskA. -// */ -// @Test -// public void testReconnectToWriter_SlowReaderA() throws SQLException { -// when(mockConnectionService.open(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); -// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); -// when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return new ReaderFailoverResult(mockReaderAConnection, readerA, true); -// }); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertFalse(result.isNewHost()); -// assertSame(result.getNewConnection(), mockWriterConnection); -// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a current writer node. -// * -// *

Topology: no changes. -// * TaskA: successfully re-connect to writer; return new connection. -// * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). -// * Expected test result: new connection by taskA. -// */ -// @Test -// public void testReconnectToWriter_taskBDefers() throws SQLException { -// when(mockConnectionService.open(refEq(writer), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return mockWriterConnection; -// }); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenThrow(SQLException.class); -// -// when(mockPluginService.getAllHosts()).thenReturn(topology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertFalse(result.isNewHost()); -// assertSame(result.getNewConnection(), mockWriterConnection); -// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a new writer node. -// * -// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. -// * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more -// * time than taskB. -// * TaskB: successfully connect to readerA and then to new-writer. -// * Expected test result: new connection to writer by taskB. -// */ -// @Test -// public void testConnectToReaderA_SlowWriter() throws SQLException { -// when(mockConnectionService.open(refEq(writer), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return mockWriterConnection; -// }); -// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); -// -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertTrue(result.isNewHost()); -// assertSame(result.getNewConnection(), mockNewWriterConnection); -// assertEquals(3, result.getTopology().size()); -// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); -// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); -// } -// -// /** -// * Verify that writer failover handler can re-connect to a new writer node. -// * -// *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. -// * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). -// * TaskB: successfully connect to readerA and then to new-writer. -// * Expected test result: new connection to writer by taskB. -// */ -// @Test -// public void testConnectToReaderA_taskADefers() throws SQLException { -// when(mockConnectionService.open(writer, properties)).thenReturn(mockConnection); -// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(5000); -// return mockNewWriterConnection; -// }); -// -// final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertTrue(result.isConnected()); -// assertTrue(result.isNewHost()); -// assertSame(result.getNewConnection(), mockNewWriterConnection); -// assertEquals(4, result.getTopology().size()); -// assertEquals("new-writer-host", result.getTopology().get(0).getHost()); -// -// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); -// assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); -// } -// -// /** -// * Verify that writer failover handler fails to re-connect to any writer node. -// * -// *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. -// * TaskA: fail to re-connect to writer due to failover timeout. -// * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. -// * Expected test result: no connection. -// */ -// @Test -// public void testFailedToConnect_failoverTimeout() throws SQLException { -// when(mockConnectionService.open(refEq(writer), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(30000); -// return mockWriterConnection; -// }); -// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))) -// .thenAnswer( -// (Answer) -// invocation -> { -// Thread.sleep(30000); -// return mockNewWriterConnection; -// }); -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); -// -// final long startTimeNano = System.nanoTime(); -// final WriterFailoverResult result = target.failover(topology); -// final long durationNano = System.nanoTime() - startTimeNano; -// -// assertFalse(result.isConnected()); -// assertFalse(result.isNewHost()); -// -// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); -// -// // 5s is a max allowed failover timeout; add 1s for inaccurate measurements -// assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); -// } -// -// /** -// * Verify that writer failover handler fails to re-connect to any writer node. -// * -// *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. -// * TaskA: fail to re-connect to writer due to exception. -// * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. -// * Expected test result: no connection. -// */ -// @Test -// public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { -// final SQLException exception = new SQLException("exception", "08S01", null); -// when(mockConnectionService.open(refEq(writer), eq(properties))).thenThrow(exception); -// when(mockConnectionService.open(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); -// when(mockConnectionService.open(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); -// when(mockConnectionService.open(refEq(newWriterHost), eq(properties))).thenThrow(exception); -// when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); -// -// when(mockPluginService.getAllHosts()).thenReturn(newTopology); -// -// when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) -// .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); -// -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); -// -// final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); -// final WriterFailoverResult result = target.failover(topology); -// -// assertFalse(result.isConnected()); -// assertFalse(result.isNewHost()); -// -// assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.failover; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.refEq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.stubbing.Answer; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.FullServicesContainer; + +class ClusterAwareWriterFailoverHandlerTest { + @Mock FullServicesContainer mockContainer1; + @Mock FullServicesContainer mockContainer2; + @Mock PluginService mockPluginService; + @Mock Connection mockConnection; + @Mock ReaderFailoverHandler mockReaderFailoverHandler; + @Mock Connection mockWriterConnection; + @Mock Connection mockNewWriterConnection; + @Mock Connection mockReaderAConnection; + @Mock Connection mockReaderBConnection; + @Mock Dialect mockDialect; + + private AutoCloseable closeable; + private final Properties properties = new Properties(); + private final HostSpec newWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("new-writer-host").build(); + private final HostSpec writer = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer-host").build(); + private final HostSpec readerA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader-a-host").build(); + private final HostSpec readerB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader-b-host").build(); + private final List topology = Arrays.asList(writer, readerA, readerB); + private final List newTopology = Arrays.asList(newWriterHost, readerA, readerB); + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + when(mockContainer1.getPluginService()).thenReturn(mockPluginService); + when(mockContainer2.getPluginService()).thenReturn(mockPluginService); + writer.addAlias("writer-host"); + newWriterHost.addAlias("new-writer-host"); + readerA.addAlias("reader-a-host"); + readerB.addAlias("reader-b-host"); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + public void testReconnectToWriter_taskBReaderException() throws SQLException { + when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockConnection); + when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenThrow(SQLException.class); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); + + when(mockPluginService.getAllHosts()).thenReturn(topology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertFalse(result.isNewHost()); + assertSame(result.getNewConnection(), mockConnection); + + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); + } + + private ClusterAwareWriterFailoverHandler getSpyFailoverHandler( + final int failoverTimeoutMs, + final int readTopologyIntervalMs, + final int reconnectWriterIntervalMs) throws SQLException { + ClusterAwareWriterFailoverHandler handler = new ClusterAwareWriterFailoverHandler( + mockContainer1, + mockReaderFailoverHandler, + properties, + failoverTimeoutMs, + readTopologyIntervalMs, + reconnectWriterIntervalMs); + + ClusterAwareWriterFailoverHandler spyHandler = spy(handler); + doReturn(mockContainer2).when(spyHandler).getNewServicesContainer(); + return spyHandler; + } + + /** + * Verify that writer failover handler can re-connect to a current writer node. + * + *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. + * TaskA: successfully re-connect to initial writer; return new connection. + * TaskB: successfully connect to readerA and then new writer, but it takes more time than taskA. + * Expected test result: new connection by taskA. + */ + @Test + public void testReconnectToWriter_SlowReaderA() throws SQLException { + when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); + when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); + when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return new ReaderFailoverResult(mockReaderAConnection, readerA, true); + }); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertFalse(result.isNewHost()); + assertSame(result.getNewConnection(), mockWriterConnection); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); + } + + /** + * Verify that writer failover handler can re-connect to a current writer node. + * + *

Topology: no changes. + * TaskA: successfully re-connect to writer; return new connection. + * TaskB: successfully connect to readerA and retrieve topology, but latest writer is not new (defer to taskA). + * Expected test result: new connection by taskA. + */ + @Test + public void testReconnectToWriter_taskBDefers() throws SQLException { + when(mockPluginService.forceConnect(refEq(writer), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return mockWriterConnection; + }); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); + + when(mockPluginService.getAllHosts()).thenReturn(topology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 2000, 2000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertFalse(result.isNewHost()); + assertSame(result.getNewConnection(), mockWriterConnection); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(writer.getHost())); + } + + /** + * Verify that writer failover handler can re-connect to a new writer node. + * + *

Topology: changes to [new-writer, reader-A, reader-B] for taskB, taskA sees no changes. + * taskA: successfully re-connect to writer; return connection to initial writer, but it takes more + * time than taskB. + * TaskB: successfully connect to readerA and then to new-writer. + * Expected test result: new connection to writer by taskB. + */ + @Test + public void testConnectToReaderA_SlowWriter() throws SQLException { + when(mockPluginService.forceConnect(refEq(writer), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return mockWriterConnection; + }); + when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); + + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertTrue(result.isNewHost()); + assertSame(result.getNewConnection(), mockNewWriterConnection); + assertEquals(3, result.getTopology().size()); + assertEquals("new-writer-host", result.getTopology().get(0).getHost()); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); + } + + /** + * Verify that writer failover handler can re-connect to a new writer node. + * + *

Topology: changes to [new-writer, initial-writer, reader-A, reader-B]. + * TaskA: successfully reconnect, but initial-writer is now a reader (defer to taskB). + * TaskB: successfully connect to readerA and then to new-writer. + * Expected test result: new connection to writer by taskB. + */ + @Test + public void testConnectToReaderA_taskADefers() throws SQLException { + when(mockPluginService.forceConnect(writer, properties)).thenReturn(mockConnection); + when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(5000); + return mockNewWriterConnection; + }); + + final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(60000, 5000, 5000); + final WriterFailoverResult result = target.failover(topology); + + assertTrue(result.isConnected()); + assertTrue(result.isNewHost()); + assertSame(result.getNewConnection(), mockNewWriterConnection); + assertEquals(4, result.getTopology().size()); + assertEquals("new-writer-host", result.getTopology().get(0).getHost()); + + verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); + assertEquals(HostAvailability.AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); + } + + /** + * Verify that writer failover handler fails to re-connect to any writer node. + * + *

Topology: no changes seen by task A, changes to [new-writer, reader-A, reader-B] for taskB. + * TaskA: fail to re-connect to writer due to failover timeout. + * TaskB: successfully connect to readerA and then fail to connect to writer due to failover timeout. + * Expected test result: no connection. + */ + @Test + public void testFailedToConnect_failoverTimeout() throws SQLException { + when(mockPluginService.forceConnect(refEq(writer), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(30000); + return mockWriterConnection; + }); + when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))) + .thenAnswer( + (Answer) + invocation -> { + Thread.sleep(30000); + return mockNewWriterConnection; + }); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); + + final long startTimeNano = System.nanoTime(); + final WriterFailoverResult result = target.failover(topology); + final long durationNano = System.nanoTime() - startTimeNano; + + assertFalse(result.isConnected()); + assertFalse(result.isNewHost()); + + verify(mockPluginService, atLeastOnce()).forceRefreshHostList(any(Connection.class)); + + // 5s is a max allowed failover timeout; add 1s for inaccurate measurements + assertTrue(TimeUnit.NANOSECONDS.toMillis(durationNano) < 6000); + } + + /** + * Verify that writer failover handler fails to re-connect to any writer node. + * + *

Topology: changes to [new-writer, reader-A, reader-B] for taskB. + * TaskA: fail to re-connect to writer due to exception. + * TaskB: successfully connect to readerA and then fail to connect to writer due to exception. + * Expected test result: no connection. + */ + @Test + public void testFailedToConnect_taskAException_taskBWriterException() throws SQLException { + final SQLException exception = new SQLException("exception", "08S01", null); + when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenThrow(exception); + when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenReturn(mockReaderAConnection); + when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); + when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenThrow(exception); + when(mockPluginService.isNetworkException(eq(exception), any())).thenReturn(true); + + when(mockPluginService.getAllHosts()).thenReturn(newTopology); + + when(mockReaderFailoverHandler.getReaderConnection(ArgumentMatchers.anyList())) + .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getFailoverRestrictions()).thenReturn(EnumSet.noneOf(FailoverRestriction.class)); + + final ClusterAwareWriterFailoverHandler target = getSpyFailoverHandler(5000, 2000, 2000); + final WriterFailoverResult result = target.failover(topology); + + assertFalse(result.isConnected()); + assertFalse(result.isNewHost()); + + assertEquals(HostAvailability.NOT_AVAILABLE, target.getHostAvailabilityMap().get(newWriterHost.getHost())); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index 0ecddc7d0..8274235b9 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -1,447 +1,445 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.failover; -// -// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.atLeastOnce; -// import static org.mockito.Mockito.doNothing; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.doThrow; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.HashMap; -// import java.util.HashSet; -// import java.util.List; -// import java.util.Map; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.ValueSource; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.NodeChangeOptions; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.RdsUrlType; -// import software.amazon.jdbc.util.SqlState; -// import software.amazon.jdbc.util.connection.ConnectionService; -// import software.amazon.jdbc.util.telemetry.GaugeCallable; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryGauge; -// -// class FailoverConnectionPluginTest { -// -// private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; -// private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; -// private static final Object[] EMPTY_ARGS = {}; -// private final List defaultHosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("writer").port(1234).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("reader1").port(1234).role(HostRole.READER).build()); -// -// @Mock FullServicesContainer mockContainer; -// @Mock ConnectionService mockConnectionService; -// @Mock PluginService mockPluginService; -// @Mock Connection mockConnection; -// @Mock HostSpec mockHostSpec; -// @Mock HostListProviderService mockHostListProviderService; -// @Mock AuroraHostListProvider mockHostListProvider; -// @Mock JdbcCallable mockInitHostProviderFunc; -// @Mock ReaderFailoverHandler mockReaderFailoverHandler; -// @Mock WriterFailoverHandler mockWriterFailoverHandler; -// @Mock ReaderFailoverResult mockReaderResult; -// @Mock WriterFailoverResult mockWriterResult; -// @Mock JdbcCallable mockSqlFunction; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryGauge mockTelemetryGauge; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// -// -// private final Properties properties = new Properties(); -// private FailoverConnectionPlugin spyPlugin; -// private AutoCloseable closeable; -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @BeforeEach -// void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// -// when(mockContainer.getPluginService()).thenReturn(mockPluginService); -// when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); -// when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); -// when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockPluginService.getHosts()).thenReturn(defaultHosts); -// when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); -// when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); -// when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); -// when(mockWriterResult.isConnected()).thenReturn(true); -// when(mockWriterResult.getTopology()).thenReturn(defaultHosts); -// when(mockReaderResult.isConnected()).thenReturn(true); -// -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// // noinspection unchecked -// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); -// -// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); -// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); -// -// properties.clear(); -// } -// -// @Test -// void test_notifyNodeListChanged_withFailoverDisabled() throws SQLException { -// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); -// final Map> changes = new HashMap<>(); -// -// initializePlugin(); -// spyPlugin.notifyNodeListChanged(changes); -// -// verify(mockPluginService, never()).getCurrentHostSpec(); -// verify(mockHostSpec, never()).getAliases(); -// } -// -// @Test -// void test_notifyNodeListChanged_withValidConnectionNotInTopology() throws SQLException { -// final Map> changes = new HashMap<>(); -// changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); -// changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); -// -// initializePlugin(); -// spyPlugin.notifyNodeListChanged(changes); -// -// when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); -// -// verify(mockPluginService).getCurrentHostSpec(); -// verify(mockHostSpec, never()).getAliases(); -// } -// -// @Test -// void test_updateTopology() throws SQLException { -// initializePlugin(); -// -// // Test updateTopology with failover disabled -// spyPlugin.setRdsUrlType(RdsUrlType.RDS_PROXY); -// spyPlugin.updateTopology(false); -// verify(mockPluginService, never()).forceRefreshHostList(); -// verify(mockPluginService, never()).refreshHostList(); -// -// // Test updateTopology with no connection -// when(mockPluginService.getCurrentHostSpec()).thenReturn(null); -// spyPlugin.updateTopology(false); -// verify(mockPluginService, never()).forceRefreshHostList(); -// verify(mockPluginService, never()).refreshHostList(); -// -// // Test updateTopology with closed connection -// when(mockConnection.isClosed()).thenReturn(true); -// spyPlugin.updateTopology(false); -// verify(mockPluginService, never()).forceRefreshHostList(); -// verify(mockPluginService, never()).refreshHostList(); -// } -// -// @ParameterizedTest -// @ValueSource(booleans = {true, false}) -// void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { -// -// when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); -// when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); -// when(mockConnection.isClosed()).thenReturn(false); -// initializePlugin(); -// spyPlugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); -// -// spyPlugin.updateTopology(forceUpdate); -// if (forceUpdate) { -// verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); -// } else { -// verify(mockPluginService, atLeastOnce()).refreshHostList(); -// } -// } -// -// @Test -// void test_failover_failoverWriter() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(true); -// -// initializePlugin(); -// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); -// spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; -// -// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); -// verify(spyPlugin).failoverWriter(); -// } -// -// @Test -// void test_failover_failoverReader() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(false); -// -// initializePlugin(); -// doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); -// spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; -// -// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); -// verify(spyPlugin).failoverReader(eq(mockHostSpec)); -// } -// -// @Test -// void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); -// when(mockReaderResult.isConnected()).thenReturn(true); -// when(mockReaderResult.getConnection()).thenReturn(mockConnection); -// when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); -// -// initializePlugin(); -// spyPlugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// (connectionService) -> mockReaderFailoverHandler, -// (connectionService) -> mockWriterFailoverHandler); -// -// final FailoverConnectionPlugin spyPlugin = spy(this.spyPlugin); -// doNothing().when(spyPlugin).updateTopology(true); -// -// assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); -// -// verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); -// verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); -// } -// -// @Test -// void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .build(); -// final List hosts = Collections.singletonList(hostSpec); -// -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); -// when(mockPluginService.getAllHosts()).thenReturn(hosts); -// when(mockPluginService.getHosts()).thenReturn(hosts); -// when(mockReaderResult.getException()).thenReturn(new SQLException()); -// when(mockReaderResult.getHost()).thenReturn(hostSpec); -// -// initializePlugin(); -// spyPlugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// (connectionService) -> mockReaderFailoverHandler, -// (connectionService) -> mockWriterFailoverHandler); -// -// assertThrows(SQLException.class, () -> spyPlugin.failoverReader(null)); -// verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); -// } -// -// @Test -// void test_failoverWriter_failedFailover_throwsException() throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .build(); -// final List hosts = Collections.singletonList(hostSpec); -// -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockPluginService.getAllHosts()).thenReturn(hosts); -// when(mockPluginService.getHosts()).thenReturn(hosts); -// when(mockWriterResult.getException()).thenReturn(new SQLException()); -// -// initializePlugin(); -// spyPlugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// (connectionService) -> mockReaderFailoverHandler, -// (connectionService) -> mockWriterFailoverHandler); -// -// assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); -// verify(mockWriterFailoverHandler).failover(eq(hosts)); -// } -// -// @Test -// void test_failoverWriter_failedFailover_withNoResult() throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .build(); -// final List hosts = Collections.singletonList(hostSpec); -// -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// when(mockPluginService.getAllHosts()).thenReturn(hosts); -// when(mockPluginService.getHosts()).thenReturn(hosts); -// when(mockWriterResult.isConnected()).thenReturn(false); -// -// initializePlugin(); -// spyPlugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// (connectionService) -> mockReaderFailoverHandler, -// (connectionService) -> mockWriterFailoverHandler); -// -// final SQLException exception = assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); -// assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); -// -// verify(mockWriterFailoverHandler).failover(eq(hosts)); -// verify(mockWriterResult, never()).getNewConnection(); -// verify(mockWriterResult, never()).getTopology(); -// } -// -// @Test -// void test_failoverWriter_successFailover() throws SQLException { -// when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); -// -// initializePlugin(); -// spyPlugin.initHostProvider( -// mockHostListProviderService, -// mockInitHostProviderFunc, -// (connectionService) -> mockReaderFailoverHandler, -// (connectionService) -> mockWriterFailoverHandler); -// -// final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverWriter()); -// assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); -// -// verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); -// } -// -// @Test -// void test_invalidCurrentConnection_withNoConnection() throws SQLException { -// when(mockPluginService.getCurrentConnection()).thenReturn(null); -// initializePlugin(); -// spyPlugin.invalidateCurrentConnection(); -// -// verify(mockPluginService, never()).getCurrentHostSpec(); -// } -// -// @Test -// void test_invalidateCurrentConnection_inTransaction() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(true); -// when(mockHostSpec.getHost()).thenReturn("host"); -// when(mockHostSpec.getPort()).thenReturn(123); -// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); -// -// initializePlugin(); -// spyPlugin.invalidateCurrentConnection(); -// verify(mockConnection).rollback(); -// -// // Assert SQL exceptions thrown during rollback do not get propagated. -// doThrow(new SQLException()).when(mockConnection).rollback(); -// assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); -// } -// -// @Test -// void test_invalidateCurrentConnection_notInTransaction() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(false); -// when(mockHostSpec.getHost()).thenReturn("host"); -// when(mockHostSpec.getPort()).thenReturn(123); -// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); -// -// initializePlugin(); -// spyPlugin.invalidateCurrentConnection(); -// -// verify(mockPluginService).isInTransaction(); -// } -// -// @Test -// void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { -// when(mockPluginService.isInTransaction()).thenReturn(false); -// when(mockConnection.isClosed()).thenReturn(false); -// when(mockHostSpec.getHost()).thenReturn("host"); -// when(mockHostSpec.getPort()).thenReturn(123); -// when(mockHostSpec.getRole()).thenReturn(HostRole.READER); -// -// initializePlugin(); -// spyPlugin.invalidateCurrentConnection(); -// -// doThrow(new SQLException()).when(mockConnection).close(); -// assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); -// -// verify(mockConnection, times(2)).isClosed(); -// verify(mockConnection, times(2)).close(); -// } -// -// @Test -// void test_execute_withFailoverDisabled() throws SQLException { -// properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); -// initializePlugin(); -// -// spyPlugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// mockSqlFunction, -// EMPTY_ARGS); -// -// verify(mockSqlFunction).call(); -// verify(mockHostListProvider, never()).getRdsUrlType(); -// } -// -// @Test -// void test_execute_withDirectExecute() throws SQLException { -// initializePlugin(); -// spyPlugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// "close", -// mockSqlFunction, -// EMPTY_ARGS); -// verify(mockSqlFunction).call(); -// verify(mockHostListProvider, never()).getRdsUrlType(); -// } -// -// private void initializePlugin() throws SQLException { -// spyPlugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); -// spyPlugin.setWriterFailoverHandler(mockWriterFailoverHandler); -// spyPlugin.setReaderFailoverHandler(mockReaderFailoverHandler); -// doReturn(mockConnectionService).when(spyPlugin).getConnectionService(); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.failover; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.NodeChangeOptions; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.RdsUrlType; +import software.amazon.jdbc.util.SqlState; +import software.amazon.jdbc.util.connection.ConnectionService; +import software.amazon.jdbc.util.telemetry.GaugeCallable; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +class FailoverConnectionPluginTest { + + private static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; + private static final String MONITOR_METHOD_NAME = "Connection.executeQuery"; + private static final Object[] EMPTY_ARGS = {}; + private final List defaultHosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("writer").port(1234).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("reader1").port(1234).role(HostRole.READER).build()); + + @Mock FullServicesContainer mockContainer; + @Mock ConnectionService mockConnectionService; + @Mock PluginService mockPluginService; + @Mock Connection mockConnection; + @Mock HostSpec mockHostSpec; + @Mock HostListProviderService mockHostListProviderService; + @Mock AuroraHostListProvider mockHostListProvider; + @Mock JdbcCallable mockInitHostProviderFunc; + @Mock ReaderFailoverHandler mockReaderFailoverHandler; + @Mock WriterFailoverHandler mockWriterFailoverHandler; + @Mock ReaderFailoverResult mockReaderResult; + @Mock WriterFailoverResult mockWriterResult; + @Mock JdbcCallable mockSqlFunction; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryGauge mockTelemetryGauge; + @Mock TargetDriverDialect mockTargetDriverDialect; + + + private final Properties properties = new Properties(); + private FailoverConnectionPlugin spyPlugin; + private AutoCloseable closeable; + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @BeforeEach + void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + + when(mockContainer.getPluginService()).thenReturn(mockPluginService); + when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); + when(mockHostListProvider.getRdsUrlType()).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); + when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockPluginService.getHosts()).thenReturn(defaultHosts); + when(mockPluginService.getAllHosts()).thenReturn(defaultHosts); + when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); + when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); + when(mockWriterResult.isConnected()).thenReturn(true); + when(mockWriterResult.getTopology()).thenReturn(defaultHosts); + when(mockReaderResult.isConnected()).thenReturn(true); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + // noinspection unchecked + when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + + when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); + when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); + + properties.clear(); + } + + @Test + void test_notifyNodeListChanged_withFailoverDisabled() throws SQLException { + properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); + final Map> changes = new HashMap<>(); + + initializePlugin(); + spyPlugin.notifyNodeListChanged(changes); + + verify(mockPluginService, never()).getCurrentHostSpec(); + verify(mockHostSpec, never()).getAliases(); + } + + @Test + void test_notifyNodeListChanged_withValidConnectionNotInTopology() throws SQLException { + final Map> changes = new HashMap<>(); + changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); + changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); + + initializePlugin(); + spyPlugin.notifyNodeListChanged(changes); + + when(mockHostSpec.getUrl()).thenReturn("cluster-url/"); + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("instance"))); + + verify(mockPluginService).getCurrentHostSpec(); + verify(mockHostSpec, never()).getAliases(); + } + + @Test + void test_updateTopology() throws SQLException { + initializePlugin(); + + // Test updateTopology with failover disabled + spyPlugin.setRdsUrlType(RdsUrlType.RDS_PROXY); + spyPlugin.updateTopology(false); + verify(mockPluginService, never()).forceRefreshHostList(); + verify(mockPluginService, never()).refreshHostList(); + + // Test updateTopology with no connection + when(mockPluginService.getCurrentHostSpec()).thenReturn(null); + spyPlugin.updateTopology(false); + verify(mockPluginService, never()).forceRefreshHostList(); + verify(mockPluginService, never()).refreshHostList(); + + // Test updateTopology with closed connection + when(mockConnection.isClosed()).thenReturn(true); + spyPlugin.updateTopology(false); + verify(mockPluginService, never()).forceRefreshHostList(); + verify(mockPluginService, never()).refreshHostList(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { + + when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); + when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); + when(mockConnection.isClosed()).thenReturn(false); + initializePlugin(); + spyPlugin.setRdsUrlType(RdsUrlType.RDS_INSTANCE); + + spyPlugin.updateTopology(forceUpdate); + if (forceUpdate) { + verify(mockPluginService, atLeastOnce()).forceRefreshHostList(); + } else { + verify(mockPluginService, atLeastOnce()).refreshHostList(); + } + } + + @Test + void test_failover_failoverWriter() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(true); + + initializePlugin(); + doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverWriter(); + spyPlugin.failoverMode = FailoverMode.STRICT_WRITER; + + assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); + verify(spyPlugin).failoverWriter(); + } + + @Test + void test_failover_failoverReader() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(false); + + initializePlugin(); + doThrow(FailoverSuccessSQLException.class).when(spyPlugin).failoverReader(eq(mockHostSpec)); + spyPlugin.failoverMode = FailoverMode.READER_OR_WRITER; + + assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failover(mockHostSpec)); + verify(spyPlugin).failoverReader(eq(mockHostSpec)); + } + + @Test + void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLException { + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); + when(mockReaderResult.isConnected()).thenReturn(true); + when(mockReaderResult.getConnection()).thenReturn(mockConnection); + when(mockReaderResult.getHost()).thenReturn(defaultHosts.get(1)); + + initializePlugin(); + spyPlugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + () -> mockReaderFailoverHandler, + () -> mockWriterFailoverHandler); + + final FailoverConnectionPlugin spyPlugin = spy(this.spyPlugin); + doNothing().when(spyPlugin).updateTopology(true); + + assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverReader(mockHostSpec)); + + verify(mockReaderFailoverHandler).failover(eq(defaultHosts), eq(mockHostSpec)); + verify(mockPluginService).setCurrentConnection(eq(mockConnection), eq(defaultHosts.get(1))); + } + + @Test + void test_failoverReader_withNoFailedHostSpec_withException() throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .build(); + final List hosts = Collections.singletonList(hostSpec); + + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); + when(mockPluginService.getAllHosts()).thenReturn(hosts); + when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockReaderResult.getException()).thenReturn(new SQLException()); + when(mockReaderResult.getHost()).thenReturn(hostSpec); + + initializePlugin(); + spyPlugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + () -> mockReaderFailoverHandler, + () -> mockWriterFailoverHandler); + + assertThrows(SQLException.class, () -> spyPlugin.failoverReader(null)); + verify(mockReaderFailoverHandler).failover(eq(hosts), eq(null)); + } + + @Test + void test_failoverWriter_failedFailover_throwsException() throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .build(); + final List hosts = Collections.singletonList(hostSpec); + + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); + when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockWriterResult.getException()).thenReturn(new SQLException()); + + initializePlugin(); + spyPlugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + () -> mockReaderFailoverHandler, + () -> mockWriterFailoverHandler); + + assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); + verify(mockWriterFailoverHandler).failover(eq(hosts)); + } + + @Test + void test_failoverWriter_failedFailover_withNoResult() throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .build(); + final List hosts = Collections.singletonList(hostSpec); + + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); + when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockWriterResult.isConnected()).thenReturn(false); + + initializePlugin(); + spyPlugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + () -> mockReaderFailoverHandler, + () -> mockWriterFailoverHandler); + + final SQLException exception = assertThrows(SQLException.class, () -> spyPlugin.failoverWriter()); + assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), exception.getSQLState()); + + verify(mockWriterFailoverHandler).failover(eq(hosts)); + verify(mockWriterResult, never()).getNewConnection(); + verify(mockWriterResult, never()).getTopology(); + } + + @Test + void test_failoverWriter_successFailover() throws SQLException { + when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + + initializePlugin(); + spyPlugin.initHostProvider( + mockHostListProviderService, + mockInitHostProviderFunc, + () -> mockReaderFailoverHandler, + () -> mockWriterFailoverHandler); + + final SQLException exception = assertThrows(FailoverSuccessSQLException.class, () -> spyPlugin.failoverWriter()); + assertEquals(SqlState.COMMUNICATION_LINK_CHANGED.getState(), exception.getSQLState()); + + verify(mockWriterFailoverHandler).failover(eq(defaultHosts)); + } + + @Test + void test_invalidCurrentConnection_withNoConnection() throws SQLException { + when(mockPluginService.getCurrentConnection()).thenReturn(null); + initializePlugin(); + spyPlugin.invalidateCurrentConnection(); + + verify(mockPluginService, never()).getCurrentHostSpec(); + } + + @Test + void test_invalidateCurrentConnection_inTransaction() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockHostSpec.getHost()).thenReturn("host"); + when(mockHostSpec.getPort()).thenReturn(123); + when(mockHostSpec.getRole()).thenReturn(HostRole.READER); + + initializePlugin(); + spyPlugin.invalidateCurrentConnection(); + verify(mockConnection).rollback(); + + // Assert SQL exceptions thrown during rollback do not get propagated. + doThrow(new SQLException()).when(mockConnection).rollback(); + assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); + } + + @Test + void test_invalidateCurrentConnection_notInTransaction() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockHostSpec.getHost()).thenReturn("host"); + when(mockHostSpec.getPort()).thenReturn(123); + when(mockHostSpec.getRole()).thenReturn(HostRole.READER); + + initializePlugin(); + spyPlugin.invalidateCurrentConnection(); + + verify(mockPluginService).isInTransaction(); + } + + @Test + void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.isClosed()).thenReturn(false); + when(mockHostSpec.getHost()).thenReturn("host"); + when(mockHostSpec.getPort()).thenReturn(123); + when(mockHostSpec.getRole()).thenReturn(HostRole.READER); + + initializePlugin(); + spyPlugin.invalidateCurrentConnection(); + + doThrow(new SQLException()).when(mockConnection).close(); + assertDoesNotThrow(() -> spyPlugin.invalidateCurrentConnection()); + + verify(mockConnection, times(2)).isClosed(); + verify(mockConnection, times(2)).close(); + } + + @Test + void test_execute_withFailoverDisabled() throws SQLException { + properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); + initializePlugin(); + + spyPlugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + mockSqlFunction, + EMPTY_ARGS); + + verify(mockSqlFunction).call(); + verify(mockHostListProvider, never()).getRdsUrlType(); + } + + @Test + void test_execute_withDirectExecute() throws SQLException { + initializePlugin(); + spyPlugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + "close", + mockSqlFunction, + EMPTY_ARGS); + verify(mockSqlFunction).call(); + verify(mockHostListProvider, never()).getRdsUrlType(); + } + + private void initializePlugin() throws SQLException { + spyPlugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); + spyPlugin.setWriterFailoverHandler(mockWriterFailoverHandler); + spyPlugin.setReaderFailoverHandler(mockReaderFailoverHandler); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 9251b27e5..02bd5060e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -1,314 +1,315 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.util.monitoring; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertNotEquals; -// import static org.junit.jupiter.api.Assertions.assertNotNull; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyInt; -// import static org.mockito.Mockito.doNothing; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.spy; -// -// import java.sql.SQLException; -// import java.util.Collections; -// import java.util.HashSet; -// import java.util.Properties; -// import java.util.concurrent.TimeUnit; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.Assertions; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.ConnectionProvider; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.connection.ConnectionService; -// import software.amazon.jdbc.util.events.EventPublisher; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// -// class MonitorServiceImplTest { -// @Mock StorageService mockStorageService; -// @Mock ConnectionService mockConnectionService; -// @Mock ConnectionProvider mockConnectionProvider; -// @Mock TelemetryFactory mockTelemetryFactory; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// @Mock Dialect mockDbDialect; -// @Mock EventPublisher mockPublisher; -// MonitorServiceImpl spyMonitorService; -// private AutoCloseable closeable; -// -// @BeforeEach -// void setUp() { -// closeable = MockitoAnnotations.openMocks(this); -// spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); -// doNothing().when(spyMonitorService).initCleanupThread(anyInt()); -// -// try { -// doReturn(mockConnectionService).when(spyMonitorService) -// .getConnectionService(any(), any(), any(), any(), any(), any(), any(), any()); -// } catch (SQLException e) { -// Assertions.fail( -// "Encountered exception while stubbing MonitorServiceImpl#getConnectionService: " + e.getMessage()); -// } -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// spyMonitorService.releaseResources(); -// } -// -// @Test -// public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// TimeUnit.MINUTES.toNanos(1), -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// "jdbc:postgresql://somehost/somedb", -// "someProtocol", -// mockTargetDriverDialect, -// mockDbDialect, -// new Properties(), -// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) -// ); -// -// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(storedMonitor); -// assertEquals(monitor, storedMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// -// monitor.state.set(MonitorState.ERROR); -// spyMonitorService.checkMonitors(); -// -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// -// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(newMonitor); -// assertNotEquals(monitor, newMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, newMonitor.getState()); -// } -// -// @Test -// public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// 1, // heartbeat times out immediately -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// "jdbc:postgresql://somehost/somedb", -// "someProtocol", -// mockTargetDriverDialect, -// mockDbDialect, -// new Properties(), -// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) -// ); -// -// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(storedMonitor); -// assertEquals(monitor, storedMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// -// // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. -// spyMonitorService.checkMonitors(); -// -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// -// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(newMonitor); -// assertNotEquals(monitor, newMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, newMonitor.getState()); -// } -// -// @Test -// public void testMonitorExpired() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms -// TimeUnit.MINUTES.toNanos(1), -// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this -// // indicates it is not being used. -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// "jdbc:postgresql://somehost/somedb", -// "someProtocol", -// mockTargetDriverDialect, -// mockDbDialect, -// new Properties(), -// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) -// ); -// -// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(storedMonitor); -// assertEquals(monitor, storedMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// -// // checkMonitors() should detect the expiration timeout and stop/remove the monitor. -// spyMonitorService.checkMonitors(); -// -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// -// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// // monitor should have been removed when checkMonitors() was called. -// assertNull(newMonitor); -// } -// -// @Test -// public void testMonitorMismatch() { -// assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( -// CustomEndpointMonitorImpl.class, -// "testMonitor", -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// "jdbc:postgresql://somehost/somedb", -// "someProtocol", -// mockTargetDriverDialect, -// mockDbDialect, -// new Properties(), -// // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor -// // service should detect this and throw an exception. -// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) -// )); -// } -// -// @Test -// public void testRemove() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// TimeUnit.MINUTES.toNanos(1), -// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this -// // indicates it is not being used. -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// "jdbc:postgresql://somehost/somedb", -// "someProtocol", -// mockTargetDriverDialect, -// mockDbDialect, -// new Properties(), -// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) -// ); -// assertNotNull(monitor); -// -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); -// assertEquals(monitor, removedMonitor); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// } -// -// @Test -// public void testStopAndRemove() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// TimeUnit.MINUTES.toNanos(1), -// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this -// // indicates it is not being used. -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// "jdbc:postgresql://somehost/somedb", -// "someProtocol", -// mockTargetDriverDialect, -// mockDbDialect, -// new Properties(), -// (connectionService, pluginService) -> new NoOpMonitor(spyMonitorService, 30) -// ); -// assertNotNull(monitor); -// -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// spyMonitorService.stopAndRemove(NoOpMonitor.class, key); -// assertNull(spyMonitorService.get(NoOpMonitor.class, key)); -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// } -// -// static class NoOpMonitor extends AbstractMonitor { -// protected NoOpMonitor( -// MonitorService monitorService, -// long terminationTimeoutSec) { -// super(terminationTimeoutSec); -// } -// -// @Override -// public void monitor() { -// // do nothing. -// } -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util.monitoring; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; + +import java.sql.SQLException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.events.EventPublisher; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class MonitorServiceImplTest { + @Mock FullServicesContainer mockServicesContainer; + @Mock StorageService mockStorageService; + @Mock ConnectionProvider mockConnectionProvider; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock Dialect mockDbDialect; + @Mock EventPublisher mockPublisher; + String URL = "jdbc:postgresql://somehost/somedb"; + String PROTOCOL = "someProtocol"; + Properties props = new Properties(); + MonitorServiceImpl spyMonitorService; + private AutoCloseable closeable; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); + doNothing().when(spyMonitorService).initCleanupThread(anyInt()); + doReturn(mockServicesContainer).when(spyMonitorService).getNewServicesContainer( + eq(mockStorageService), + eq(mockConnectionProvider), + eq(mockTelemetryFactory), + eq(URL), + eq(PROTOCOL), + eq(mockTargetDriverDialect), + eq(mockDbDialect), + eq(props)); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + spyMonitorService.releaseResources(); + } + + @Test + public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + TimeUnit.MINUTES.toNanos(1), + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + URL, + PROTOCOL, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(storedMonitor); + assertEquals(monitor, storedMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, monitor.getState()); + + monitor.state.set(MonitorState.ERROR); + spyMonitorService.checkMonitors(); + + assertEquals(MonitorState.STOPPED, monitor.getState()); + + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(newMonitor); + assertNotEquals(monitor, newMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, newMonitor.getState()); + } + + @Test + public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + 1, // heartbeat times out immediately + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + URL, + PROTOCOL, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(storedMonitor); + assertEquals(monitor, storedMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, monitor.getState()); + + // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. + spyMonitorService.checkMonitors(); + + assertEquals(MonitorState.STOPPED, monitor.getState()); + + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(newMonitor); + assertNotEquals(monitor, newMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, newMonitor.getState()); + } + + @Test + public void testMonitorExpired() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms + TimeUnit.MINUTES.toNanos(1), + // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this + // indicates it is not being used. + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + URL, + PROTOCOL, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(storedMonitor); + assertEquals(monitor, storedMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, monitor.getState()); + + // checkMonitors() should detect the expiration timeout and stop/remove the monitor. + spyMonitorService.checkMonitors(); + + assertEquals(MonitorState.STOPPED, monitor.getState()); + + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); + // monitor should have been removed when checkMonitors() was called. + assertNull(newMonitor); + } + + @Test + public void testMonitorMismatch() { + assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( + CustomEndpointMonitorImpl.class, + "testMonitor", + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + URL, + PROTOCOL, + mockTargetDriverDialect, + mockDbDialect, + props, + // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor + // service should detect this and throw an exception. + (mockServicesContainer) -> new NoOpMonitor(30) + )); + } + + @Test + public void testRemove() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + TimeUnit.MINUTES.toNanos(1), + // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this + // indicates it is not being used. + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + URL, + PROTOCOL, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + assertNotNull(monitor); + + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); + assertEquals(monitor, removedMonitor); + assertEquals(MonitorState.RUNNING, monitor.getState()); + } + + @Test + public void testStopAndRemove() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + TimeUnit.MINUTES.toNanos(1), + // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this + // indicates it is not being used. + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + URL, + PROTOCOL, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + assertNotNull(monitor); + + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + spyMonitorService.stopAndRemove(NoOpMonitor.class, key); + assertNull(spyMonitorService.get(NoOpMonitor.class, key)); + assertEquals(MonitorState.STOPPED, monitor.getState()); + } + + static class NoOpMonitor extends AbstractMonitor { + protected NoOpMonitor(long terminationTimeoutSec) { + super(terminationTimeoutSec); + } + + @Override + public void monitor() { + // do nothing. + } + } +} From 85252d00d562a07ea1c6dccaa7a59c5958d7ab9f Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 15 Sep 2025 14:40:25 -0700 Subject: [PATCH 38/54] Cleanup --- .../main/java/software/amazon/jdbc/PartialPluginService.java | 4 ---- .../main/java/software/amazon/jdbc/util/ServiceUtility.java | 4 ++++ .../amazon/jdbc/util/connection/ConnectionServiceImpl.java | 4 ++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index c6b494651..06957b61f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -109,10 +109,6 @@ public PartialPluginService( @NonNull final Dialect dbDialect, @Nullable final ConfigurationProfile configurationProfile) { this.servicesContainer = servicesContainer; - this.servicesContainer.setHostListProviderService(this); - this.servicesContainer.setPluginService(this); - this.servicesContainer.setPluginManagerService(this); - this.pluginManager = servicesContainer.getConnectionPluginManager(); this.props = props; this.originalUrl = originalUrl; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index 84d6be614..13d5d9cd6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -84,6 +84,10 @@ public FullServicesContainer createServiceContainer( dbDialect ); + servicesContainer.setHostListProviderService(partialPluginService); + servicesContainer.setPluginService(partialPluginService); + servicesContainer.setPluginManagerService(partialPluginService); + pluginManager.init(servicesContainer, props, partialPluginService, null); return servicesContainer; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 52744d494..cc509c26a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -68,6 +68,10 @@ public ConnectionServiceImpl( dbDialect ); + servicesContainer.setHostListProviderService(partialPluginService); + servicesContainer.setPluginService(partialPluginService); + servicesContainer.setPluginManagerService(partialPluginService); + this.pluginService = partialPluginService; this.pluginManager.init(servicesContainer, props, partialPluginService, null); } From c465a006660f9047d0510824041d6bc20e015b61 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 15 Sep 2025 15:45:11 -0700 Subject: [PATCH 39/54] Cleanup --- .../ClusterAwareReaderFailoverHandler.java | 4 ---- .../util/connection/ConnectionService.java | 12 ++++++++++-- .../connection/ConnectionServiceImpl.java | 19 +++++++++++++++++-- .../FailoverConnectionPluginTest.java | 12 +++++------- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 8cfed799a..e58378ddb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -33,20 +33,16 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.logging.Logger; -import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.FullServicesContainerImpl; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.ServiceUtility; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionService; /** * An implementation of ReaderFailoverHandler. diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java index 5a6f4fcff..088f5d9d8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java @@ -22,6 +22,12 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; +/** + * @deprecated This interface is deprecated and will be removed in a future version. Use + * {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by + * {@link PluginService#forceConnect} instead. + */ +@Deprecated public interface ConnectionService { /** * Creates an auxiliary connection. Auxiliary connections are driver-internal connections that accomplish various @@ -31,8 +37,10 @@ public interface ConnectionService { * @param props the properties for the auxiliary connection. * @return a new connection to the given host using the given props. * @throws SQLException if an error occurs while opening the connection. + * @deprecated Use {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by + * {@link PluginService#forceConnect} instead. */ - Connection open(HostSpec hostSpec, Properties props) throws SQLException; + @Deprecated Connection open(HostSpec hostSpec, Properties props) throws SQLException; - PluginService getPluginService(); + @Deprecated PluginService getPluginService(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index c50770233..4ba674d59 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -16,6 +16,7 @@ package software.amazon.jdbc.util.connection; +import com.mchange.v2.util.PropertiesUtils; import java.sql.Connection; import java.sql.SQLException; import java.util.Properties; @@ -28,15 +29,26 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.FullServicesContainerImpl; +import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; +/** + * @deprecated This class is deprecated and will be removed in a future version. Use + * {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by + * {@link PluginService#forceConnect} instead. + */ +@Deprecated public class ConnectionServiceImpl implements ConnectionService { protected final String targetDriverProtocol; protected final ConnectionPluginManager pluginManager; protected final PluginService pluginService; + /** + * @deprecated Use {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} instead. + */ + @Deprecated public ConnectionServiceImpl( StorageService storageService, MonitorService monitorService, @@ -58,9 +70,10 @@ public ConnectionServiceImpl( telemetryFactory); servicesContainer.setConnectionPluginManager(this.pluginManager); + Properties propsCopy = PropertyUtils.copyProperties(props); PartialPluginService partialPluginService = new PartialPluginService( servicesContainer, - props, + propsCopy, originalUrl, this.targetDriverProtocol, driverDialect, @@ -72,15 +85,17 @@ public ConnectionServiceImpl( servicesContainer.setPluginManagerService(partialPluginService); this.pluginService = partialPluginService; - this.pluginManager.init(servicesContainer, props, partialPluginService, null); + this.pluginManager.init(servicesContainer, propsCopy, partialPluginService, null); } @Override + @Deprecated public Connection open(HostSpec hostSpec, Properties props) throws SQLException { return this.pluginManager.forceConnect(this.targetDriverProtocol, hostSpec, props, true, null); } @Override + @Deprecated public PluginService getPluginService() { return this.pluginService; } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index 8274235b9..637793772 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -63,7 +63,6 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.SqlState; -import software.amazon.jdbc.util.connection.ConnectionService; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -82,7 +81,6 @@ class FailoverConnectionPluginTest { .host("reader1").port(1234).role(HostRole.READER).build()); @Mock FullServicesContainer mockContainer; - @Mock ConnectionService mockConnectionService; @Mock PluginService mockPluginService; @Mock Connection mockConnection; @Mock HostSpec mockHostSpec; @@ -143,7 +141,7 @@ void init() throws SQLException { } @Test - void test_notifyNodeListChanged_withFailoverDisabled() throws SQLException { + void test_notifyNodeListChanged_withFailoverDisabled() { properties.setProperty(FailoverConnectionPlugin.ENABLE_CLUSTER_AWARE_FAILOVER.name, "false"); final Map> changes = new HashMap<>(); @@ -155,7 +153,7 @@ void test_notifyNodeListChanged_withFailoverDisabled() throws SQLException { } @Test - void test_notifyNodeListChanged_withValidConnectionNotInTopology() throws SQLException { + void test_notifyNodeListChanged_withValidConnectionNotInTopology() { final Map> changes = new HashMap<>(); changes.put("cluster-host/", EnumSet.of(NodeChangeOptions.NODE_DELETED)); changes.put("instance/", EnumSet.of(NodeChangeOptions.NODE_ADDED)); @@ -351,7 +349,7 @@ void test_failoverWriter_successFailover() throws SQLException { } @Test - void test_invalidCurrentConnection_withNoConnection() throws SQLException { + void test_invalidCurrentConnection_withNoConnection() { when(mockPluginService.getCurrentConnection()).thenReturn(null); initializePlugin(); spyPlugin.invalidateCurrentConnection(); @@ -376,7 +374,7 @@ void test_invalidateCurrentConnection_inTransaction() throws SQLException { } @Test - void test_invalidateCurrentConnection_notInTransaction() throws SQLException { + void test_invalidateCurrentConnection_notInTransaction() { when(mockPluginService.isInTransaction()).thenReturn(false); when(mockHostSpec.getHost()).thenReturn("host"); when(mockHostSpec.getPort()).thenReturn(123); @@ -437,7 +435,7 @@ void test_execute_withDirectExecute() throws SQLException { verify(mockHostListProvider, never()).getRdsUrlType(); } - private void initializePlugin() throws SQLException { + private void initializePlugin() { spyPlugin = spy(new FailoverConnectionPlugin(mockContainer, properties)); spyPlugin.setWriterFailoverHandler(mockWriterFailoverHandler); spyPlugin.setReaderFailoverHandler(mockReaderFailoverHandler); From b2a107ded5b4cc2e1b5904014be8c546623cf516 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 15 Sep 2025 15:51:03 -0700 Subject: [PATCH 40/54] cleanup --- .../amazon/ReadWriteSplittingPostgresExample.java | 2 -- .../ReadWriteSplittingSpringJdbcTemplateMySQLExample.java | 1 - .../src/main/java/software/amazon/HikariExample.java | 1 - .../src/main/java/example/spring/Config.java | 7 ------- .../main/java/software/amazon/jdbc/HostSpecBuilder.java | 1 - .../amazon/jdbc/authentication/AwsCredentialsManager.java | 1 - .../jdbc/hostavailability/HostAvailabilityStrategy.java | 2 -- .../plugin/federatedauth/CredentialsProviderFactory.java | 1 - .../jdbc/plugin/limitless/LimitlessRouterService.java | 2 -- .../amazon/jdbc/plugin/limitless/LimitlessRouters.java | 1 - .../fastestresponse/HostResponseTimeServiceImpl.java | 8 ++------ .../jdbc/targetdriverdialect/MariadbDriverHelper.java | 1 - .../software/amazon/jdbc/util/ConnectionUrlParser.java | 1 - .../java/software/amazon/jdbc/util/PropertyUtils.java | 1 - .../jdbc/util/connection/ConnectionServiceImpl.java | 1 - .../amazon/jdbc/util/monitoring/MonitorServiceImpl.java | 1 - .../amazon/jdbc/util/storage/SlidingExpirationCache.java | 1 - .../container/tests/hibernate/HibernateTests.java | 1 - .../src/test/java/integration/host/TestEnvironment.java | 1 - .../src/test/java/integration/util/ContainerHelper.java | 1 - .../amazon/jdbc/ConnectionPluginChainBuilderTests.java | 1 - .../src/test/java/software/amazon/jdbc/DialectTests.java | 1 - .../jdbc/authentication/AwsCredentialsManagerTest.java | 1 - .../jdbc/plugin/iam/IamAuthConnectionPluginTest.java | 4 ---- .../readwritesplitting/ReadWriteSplittingPluginTest.java | 1 - .../amazon/jdbc/states/SessionStateServiceImplTests.java | 1 - .../amazon/jdbc/util/ConnectionUrlParserTest.java | 1 - 27 files changed, 2 insertions(+), 44 deletions(-) diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingPostgresExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingPostgresExample.java index 1a4b42cf7..c4a273dc9 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingPostgresExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingPostgresExample.java @@ -23,8 +23,6 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.Properties; -import software.amazon.jdbc.ConnectionProviderManager; -import software.amazon.jdbc.HikariPooledConnectionProvider; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.plugin.failover.FailoverFailedSQLException; diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingSpringJdbcTemplateMySQLExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingSpringJdbcTemplateMySQLExample.java index ead361b58..ea914e72e 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingSpringJdbcTemplateMySQLExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingSpringJdbcTemplateMySQLExample.java @@ -16,7 +16,6 @@ package software.amazon; -import com.mysql.cj.jdbc.MysqlDataSource; import com.zaxxer.hikari.HikariDataSource; import java.sql.Connection; import java.sql.SQLException; diff --git a/examples/HikariExample/src/main/java/software/amazon/HikariExample.java b/examples/HikariExample/src/main/java/software/amazon/HikariExample.java index 9bbfa03c9..8f6607798 100644 --- a/examples/HikariExample/src/main/java/software/amazon/HikariExample.java +++ b/examples/HikariExample/src/main/java/software/amazon/HikariExample.java @@ -22,7 +22,6 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; -import java.util.Properties; public class HikariExample { diff --git a/examples/SpringHibernateBalancedReaderTwoDataSourceExample/src/main/java/example/spring/Config.java b/examples/SpringHibernateBalancedReaderTwoDataSourceExample/src/main/java/example/spring/Config.java index bc404df52..d6b30ffa8 100644 --- a/examples/SpringHibernateBalancedReaderTwoDataSourceExample/src/main/java/example/spring/Config.java +++ b/examples/SpringHibernateBalancedReaderTwoDataSourceExample/src/main/java/example/spring/Config.java @@ -16,10 +16,7 @@ package example.spring; -import com.zaxxer.hikari.HikariConfig; -import java.util.Arrays; import java.util.Properties; -import java.util.concurrent.TimeUnit; import javax.persistence.EntityManagerFactory; import javax.sql.DataSource; import org.hibernate.exception.JDBCConnectionException; @@ -40,10 +37,6 @@ import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.annotation.EnableTransactionManagement; -import software.amazon.jdbc.HikariPooledConnectionProvider; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.profile.ConfigurationProfilePresetCodes; @Configuration @EnableTransactionManagement diff --git a/wrapper/src/main/java/software/amazon/jdbc/HostSpecBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/HostSpecBuilder.java index a84920637..ec6abdbee 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HostSpecBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HostSpecBuilder.java @@ -17,7 +17,6 @@ package software.amazon.jdbc; import java.sql.Timestamp; -import java.time.Instant; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; diff --git a/wrapper/src/main/java/software/amazon/jdbc/authentication/AwsCredentialsManager.java b/wrapper/src/main/java/software/amazon/jdbc/authentication/AwsCredentialsManager.java index eb463a7aa..c562dd48e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/authentication/AwsCredentialsManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/authentication/AwsCredentialsManager.java @@ -23,7 +23,6 @@ import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; public class AwsCredentialsManager { diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostavailability/HostAvailabilityStrategy.java b/wrapper/src/main/java/software/amazon/jdbc/hostavailability/HostAvailabilityStrategy.java index 659fdae2b..9f04b54c6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostavailability/HostAvailabilityStrategy.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostavailability/HostAvailabilityStrategy.java @@ -16,8 +16,6 @@ package software.amazon.jdbc.hostavailability; -import software.amazon.jdbc.AwsWrapperProperty; - public interface HostAvailabilityStrategy { void setHostAvailability(HostAvailability hostAvailability); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java index a43396bf9..655fbdda4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java @@ -16,7 +16,6 @@ package software.amazon.jdbc.plugin.federatedauth; -import java.io.Closeable; import java.sql.SQLException; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterService.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterService.java index 2d3d87a08..1e3a04560 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterService.java @@ -16,9 +16,7 @@ package software.amazon.jdbc.plugin.limitless; -import java.sql.Connection; import java.sql.SQLException; -import java.util.List; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouters.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouters.java index 0793dbcff..ea7cab3ce 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouters.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouters.java @@ -20,7 +20,6 @@ import java.util.Objects; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.hostlistprovider.Topology; public class LimitlessRouters { private final @NonNull List hosts; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java index f6b9cf177..ee157b3ea 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java @@ -20,16 +20,13 @@ import java.util.ArrayList; import java.util.List; import java.util.Properties; -import java.util.Set; import java.util.logging.Logger; -import java.util.stream.Collectors; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.storage.SlidingExpirationCacheWithCleanupThread; public class HostResponseTimeServiceImpl implements HostResponseTimeService { @@ -73,7 +70,7 @@ public void setHosts(final @NonNull List hosts) { // Going through all hosts in the topology and trying to find new ones. this.hosts.stream() // hostSpec is not in the set of hosts that already being monitored - .filter(hostSpec ->!Utils.containsHostAndPort(oldHosts, hostSpec.getHostAndPort())) + .filter(hostSpec -> !Utils.containsHostAndPort(oldHosts, hostSpec.getHostAndPort())) .forEach(hostSpec -> { try { this.servicesContainer.getMonitorService().runIfAbsent( @@ -88,8 +85,7 @@ public void setHosts(final @NonNull List hosts) { this.pluginService.getDialect(), this.props, (servicesContainer) -> - new NodeResponseTimeMonitor(pluginService, hostSpec, this.props, - this.intervalMs)); + new NodeResponseTimeMonitor(pluginService, hostSpec, this.props, this.intervalMs)); } catch (SQLException e) { LOGGER.warning( Messages.get("HostResponseTimeServiceImpl.errorStartingMonitor", new Object[] {hostSpec.getUrl(), e})); diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbDriverHelper.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbDriverHelper.java index 8c93819c3..d67638b6b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbDriverHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbDriverHelper.java @@ -18,7 +18,6 @@ import static software.amazon.jdbc.util.ConnectionUrlBuilder.buildUrl; -import com.mysql.cj.jdbc.Driver; import java.sql.DriverManager; import java.sql.SQLException; import java.util.Collections; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ConnectionUrlParser.java b/wrapper/src/main/java/software/amazon/jdbc/util/ConnectionUrlParser.java index 81323335b..435907141 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ConnectionUrlParser.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ConnectionUrlParser.java @@ -16,7 +16,6 @@ package software.amazon.jdbc.util; -import com.fasterxml.jackson.databind.annotation.JsonAppend.Prop; import java.util.ArrayList; import java.util.List; import java.util.Map; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/PropertyUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/PropertyUtils.java index ce2f66fad..632aaac6c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/PropertyUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/PropertyUtils.java @@ -28,7 +28,6 @@ import java.util.Set; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; -import software.amazon.awssdk.services.rds.endpoints.internal.Value.Bool; import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.PropertyDefinition; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 4ba674d59..14f3c0166 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -16,7 +16,6 @@ package software.amazon.jdbc.util.connection; -import com.mchange.v2.util.PropertiesUtils; import java.sql.Connection; import java.sql.SQLException; import java.util.Properties; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index e75e3c2f4..5e0df4817 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -41,7 +41,6 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.ServiceUtility; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/storage/SlidingExpirationCache.java b/wrapper/src/main/java/software/amazon/jdbc/util/storage/SlidingExpirationCache.java index 7f670866e..604b603f2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/storage/SlidingExpirationCache.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/storage/SlidingExpirationCache.java @@ -20,7 +20,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; diff --git a/wrapper/src/test/java/integration/container/tests/hibernate/HibernateTests.java b/wrapper/src/test/java/integration/container/tests/hibernate/HibernateTests.java index 01d263f6b..fd4e1eac9 100644 --- a/wrapper/src/test/java/integration/container/tests/hibernate/HibernateTests.java +++ b/wrapper/src/test/java/integration/container/tests/hibernate/HibernateTests.java @@ -20,7 +20,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import integration.DatabaseEngine; -import integration.DatabaseEngineDeployment; import integration.DriverHelper; import integration.TestEnvironmentFeatures; import integration.container.ConnectionStringHelper; diff --git a/wrapper/src/test/java/integration/host/TestEnvironment.java b/wrapper/src/test/java/integration/host/TestEnvironment.java index 424b741ea..ce639a152 100644 --- a/wrapper/src/test/java/integration/host/TestEnvironment.java +++ b/wrapper/src/test/java/integration/host/TestEnvironment.java @@ -50,7 +50,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; -import org.testcontainers.containers.BindMode; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.Network; import org.testcontainers.containers.ToxiproxyContainer; diff --git a/wrapper/src/test/java/integration/util/ContainerHelper.java b/wrapper/src/test/java/integration/util/ContainerHelper.java index bc977dcda..6745badd8 100644 --- a/wrapper/src/test/java/integration/util/ContainerHelper.java +++ b/wrapper/src/test/java/integration/util/ContainerHelper.java @@ -42,7 +42,6 @@ import org.testcontainers.containers.ToxiproxyContainer; import org.testcontainers.containers.output.FrameConsumerResultCallback; import org.testcontainers.containers.output.OutputFrame; -import org.testcontainers.containers.startupcheck.StartupCheckStrategy; import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.images.builder.ImageFromDockerfile; diff --git a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginChainBuilderTests.java b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginChainBuilderTests.java index 84efe5ba3..4352d0622 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginChainBuilderTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginChainBuilderTests.java @@ -28,7 +28,6 @@ import java.util.HashSet; import java.util.List; import java.util.Properties; -import java.util.Set; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java index 0caefc973..4170f8556 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java @@ -35,7 +35,6 @@ import org.mockito.MockitoAnnotations; import software.amazon.jdbc.dialect.AuroraMysqlDialect; import software.amazon.jdbc.dialect.AuroraPgDialect; -import software.amazon.jdbc.dialect.DialectManager; import software.amazon.jdbc.dialect.MariaDbDialect; import software.amazon.jdbc.dialect.MysqlDialect; import software.amazon.jdbc.dialect.PgDialect; diff --git a/wrapper/src/test/java/software/amazon/jdbc/authentication/AwsCredentialsManagerTest.java b/wrapper/src/test/java/software/amazon/jdbc/authentication/AwsCredentialsManagerTest.java index 6028f5706..56d24e076 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/authentication/AwsCredentialsManagerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/authentication/AwsCredentialsManagerTest.java @@ -17,7 +17,6 @@ package software.amazon.jdbc.authentication; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java index 92a0c0ff6..a872ac96c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java @@ -44,17 +44,13 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.rds.RdsUtilities; -import software.amazon.awssdk.services.rds.TestDefaultRdsUtilities; import software.amazon.jdbc.Driver; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.HostSpecBuilder; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.authentication.AwsCredentialsManager; import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.plugin.TokenInfo; import software.amazon.jdbc.util.RdsUtils; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index ff7282483..c7c7bdc1b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -45,7 +45,6 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HikariPooledConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; diff --git a/wrapper/src/test/java/software/amazon/jdbc/states/SessionStateServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/states/SessionStateServiceImplTests.java index 8ab28907d..ed391f4cf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/states/SessionStateServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/states/SessionStateServiceImplTests.java @@ -17,7 +17,6 @@ package software.amazon.jdbc.states; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/ConnectionUrlParserTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/ConnectionUrlParserTest.java index 082d2f78b..056fccd78 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/ConnectionUrlParserTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/ConnectionUrlParserTest.java @@ -32,7 +32,6 @@ import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; class ConnectionUrlParserTest { From e6357455fd49b3429cf50470fe21bfaa39df56df Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 15 Sep 2025 15:58:59 -0700 Subject: [PATCH 41/54] Fix checkstyle --- .../util/connection/ConnectionService.java | 20 +++++++++--- .../connection/ConnectionServiceImpl.java | 8 +++-- .../monitoring/MonitorServiceImplTest.java | 32 +++++++++---------- 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java index 088f5d9d8..e0748bc72 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionService.java @@ -21,11 +21,14 @@ import java.util.Properties; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.FullServicesContainer; /** + * A service used to open new connections for internal driver use. + * * @deprecated This interface is deprecated and will be removed in a future version. Use - * {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by - * {@link PluginService#forceConnect} instead. + * {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by + * {@link PluginService#forceConnect} instead. */ @Deprecated public interface ConnectionService { @@ -40,7 +43,16 @@ public interface ConnectionService { * @deprecated Use {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by * {@link PluginService#forceConnect} instead. */ - @Deprecated Connection open(HostSpec hostSpec, Properties props) throws SQLException; + @Deprecated + Connection open(HostSpec hostSpec, Properties props) throws SQLException; - @Deprecated PluginService getPluginService(); + /** + * Get the {@link PluginService} associated with this {@link ConnectionService}. + * + * @return the {@link PluginService} associated with this {@link ConnectionService} + * @deprecated Use {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by + * {@link FullServicesContainer#getPluginService()} instead. + */ + @Deprecated + PluginService getPluginService(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 14f3c0166..9e8356eb9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -34,9 +34,11 @@ import software.amazon.jdbc.util.telemetry.TelemetryFactory; /** + * A service used to open new connections for internal driver use. + * * @deprecated This class is deprecated and will be removed in a future version. Use - * {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by - * {@link PluginService#forceConnect} instead. + * {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} followed by + * {@link PluginService#forceConnect} instead. */ @Deprecated public class ConnectionServiceImpl implements ConnectionService { @@ -45,6 +47,8 @@ public class ConnectionServiceImpl implements ConnectionService { protected final PluginService pluginService; /** + * Constructs a {@link ConnectionServiceImpl} instance. + * * @deprecated Use {@link software.amazon.jdbc.util.ServiceUtility#createServiceContainer} instead. */ @Deprecated diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 02bd5060e..450b494a7 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -54,8 +54,8 @@ class MonitorServiceImplTest { @Mock TargetDriverDialect mockTargetDriverDialect; @Mock Dialect mockDbDialect; @Mock EventPublisher mockPublisher; - String URL = "jdbc:postgresql://somehost/somedb"; - String PROTOCOL = "someProtocol"; + String url = "jdbc:postgresql://somehost/somedb"; + String protocol = "someProtocol"; Properties props = new Properties(); MonitorServiceImpl spyMonitorService; private AutoCloseable closeable; @@ -69,8 +69,8 @@ void setUp() throws SQLException { eq(mockStorageService), eq(mockConnectionProvider), eq(mockTelemetryFactory), - eq(URL), - eq(PROTOCOL), + eq(url), + eq(protocol), eq(mockTargetDriverDialect), eq(mockDbDialect), eq(props)); @@ -98,8 +98,8 @@ public void testMonitorError_monitorReCreated() throws SQLException, Interrupted mockStorageService, mockTelemetryFactory, mockConnectionProvider, - URL, - PROTOCOL, + url, + protocol, mockTargetDriverDialect, mockDbDialect, props, @@ -142,8 +142,8 @@ public void testMonitorStuck_monitorReCreated() throws SQLException, Interrupted mockStorageService, mockTelemetryFactory, mockConnectionProvider, - URL, - PROTOCOL, + url, + protocol, mockTargetDriverDialect, mockDbDialect, props, @@ -188,8 +188,8 @@ public void testMonitorExpired() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - URL, - PROTOCOL, + url, + protocol, mockTargetDriverDialect, mockDbDialect, props, @@ -221,8 +221,8 @@ public void testMonitorMismatch() { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - URL, - PROTOCOL, + url, + protocol, mockTargetDriverDialect, mockDbDialect, props, @@ -251,8 +251,8 @@ public void testRemove() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - URL, - PROTOCOL, + url, + protocol, mockTargetDriverDialect, mockDbDialect, props, @@ -286,8 +286,8 @@ public void testStopAndRemove() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - URL, - PROTOCOL, + url, + protocol, mockTargetDriverDialect, mockDbDialect, props, From b695b867a693d91553690e4b113cc3aa08261108 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 15 Sep 2025 18:02:30 -0700 Subject: [PATCH 42/54] wip --- .../ConnectionPluginManagerBenchmarks.java | 46 +- .../testplugin/BenchmarkPlugin.java | 6 +- .../jdbc/C3P0PooledConnectionProvider.java | 3 +- .../amazon/jdbc/ConnectionPlugin.java | 5 +- .../amazon/jdbc/ConnectionPluginManager.java | 9 +- .../amazon/jdbc/ConnectionProvider.java | 8 +- .../jdbc/ConnectionProviderManager.java | 21 +- .../jdbc/DataSourceConnectionProvider.java | 15 +- .../amazon/jdbc/DriverConnectionProvider.java | 15 +- .../jdbc/HikariPooledConnectionProvider.java | 6 +- .../software/amazon/jdbc/PluginService.java | 7 + .../amazon/jdbc/PluginServiceImpl.java | 109 ++-- .../amazon/jdbc/dialect/DialectManager.java | 25 +- .../amazon/jdbc/dialect/DialectProvider.java | 6 +- .../dialect/HostListProviderSupplier.java | 5 +- .../amazon/jdbc/dialect/MysqlDialect.java | 4 +- .../ConnectionStringHostListProvider.java | 24 +- .../util/connection/ConnectionContext.java | 63 +++ .../jdbc/wrapper/ConnectionWrapper.java | 37 +- .../HikariPooledConnectionProviderTest.java | 494 +++++++++--------- 20 files changed, 453 insertions(+), 455 deletions(-) create mode 100644 wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java index 6cb455805..e1078539a 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; +import java.lang.annotation.Target; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -69,6 +70,7 @@ import software.amazon.jdbc.profile.ConfigurationProfileBuilder; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; @@ -88,8 +90,9 @@ public class ConnectionPluginManagerBenchmarks { private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; private static final String FIELD_SERVER_ID = "SERVER_ID"; private static final String FIELD_SESSION_ID = "SESSION_ID"; - private Properties propertiesWithoutPlugins; - private Properties propertiesWithPlugins; + private static final String url = "protocol//url"; + private ConnectionContext pluginsContext; + private ConnectionContext noPluginsContext; private ConnectionPluginManager pluginManager; private ConnectionPluginManager pluginManagerWithNoPlugins; @@ -98,6 +101,7 @@ public class ConnectionPluginManagerBenchmarks { @Mock FullServicesContainer mockServicesContainer; @Mock PluginService mockPluginService; @Mock PluginManagerService mockPluginManagerService; + @Mock TargetDriverDialect mockDriverDialect; @Mock TelemetryFactory mockTelemetryFactory; @Mock HostListProviderService mockHostListProvider; @Mock Connection mockConnection; @@ -153,24 +157,26 @@ public void setUpIteration() throws Exception { .withPluginFactories(pluginFactories) .build(); - propertiesWithoutPlugins = new Properties(); - propertiesWithoutPlugins.setProperty(PropertyDefinition.PLUGINS.name, ""); + Properties noPluginsProps = new Properties(); + noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); + this.noPluginsContext = new ConnectionContext(url, mockDriverDialect, noPluginsProps); - propertiesWithPlugins = new Properties(); - propertiesWithPlugins.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); - propertiesWithPlugins.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); + Properties pluginsProps = new Properties(); + pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); + pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); + this.pluginsContext = new ConnectionContext(url, mockDriverDialect, pluginsProps); - TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(propertiesWithPlugins); + TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); pluginManager = new ConnectionPluginManager(mockConnectionProvider, null, mockConnectionWrapper, telemetryFactory); - pluginManager.init(mockServicesContainer, propertiesWithPlugins, mockPluginManagerService, configurationProfile); + pluginManager.init(mockServicesContainer, pluginsProps, mockPluginManagerService, configurationProfile); pluginManagerWithNoPlugins = new ConnectionPluginManager(mockConnectionProvider, null, mockConnectionWrapper, telemetryFactory); - pluginManagerWithNoPlugins.init(mockServicesContainer, propertiesWithoutPlugins, mockPluginManagerService, null); + pluginManagerWithNoPlugins.init(mockServicesContainer, noPluginsProps, mockPluginManagerService, null); } @TearDown(Level.Iteration) @@ -182,7 +188,7 @@ public void tearDownIteration() throws Exception { public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws SQLException { final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, mockConnectionWrapper, mockTelemetryFactory); - manager.init(mockServicesContainer, propertiesWithoutPlugins, mockPluginManagerService, configurationProfile); + manager.init(mockServicesContainer, this.noPluginsContext.getProps(), mockPluginManagerService, configurationProfile); return manager; } @@ -190,7 +196,7 @@ public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws SQLException { final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, mockConnectionWrapper, mockTelemetryFactory); - manager.init(mockServicesContainer, propertiesWithPlugins, mockPluginManagerService, configurationProfile); + manager.init(mockServicesContainer, this.pluginsContext.getProps(), mockPluginManagerService, configurationProfile); return manager; } @@ -199,7 +205,7 @@ public Connection connectWithPlugins() throws SQLException { return pluginManager.connect( "driverProtocol", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), - propertiesWithPlugins, + this.pluginsContext.getProps(), true, null); } @@ -209,7 +215,7 @@ public Connection connectWithNoPlugins() throws SQLException { return pluginManagerWithNoPlugins.connect( "driverProtocol", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), - propertiesWithoutPlugins, + this.noPluginsContext.getProps(), true, null); } @@ -240,21 +246,13 @@ public Integer executeWithNoPlugins() { @Benchmark public ConnectionPluginManager initHostProvidersWithPlugins() throws SQLException { - pluginManager.initHostProvider( - "protocol", - "url", - propertiesWithPlugins, - mockHostListProvider); + pluginManager.initHostProvider(this.pluginsContext, mockHostListProvider); return pluginManager; } @Benchmark public ConnectionPluginManager initHostProvidersWithNoPlugins() throws SQLException { - pluginManagerWithNoPlugins.initHostProvider( - "protocol", - "url", - propertiesWithoutPlugins, - mockHostListProvider); + pluginManagerWithNoPlugins.initHostProvider(this.noPluginsContext, mockHostListProvider); return pluginManager; } diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java index ffed10f77..f238fbe06 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java @@ -37,6 +37,7 @@ import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.connection.ConnectionContext; public class BenchmarkPlugin implements ConnectionPlugin, CanReleaseResources { final List resources = new ArrayList<>(); @@ -93,10 +94,11 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin } @Override - public void initHostProvider(String driverProtocol, String initialUrl, Properties props, + public void initHostProvider( + ConnectionContext connectionContext, HostListProviderService hostListProviderService, JdbcCallable initHostProviderFunc) { - LOGGER.finer(() -> String.format("initHostProvider=''%s''", initialUrl)); + LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectionContext.getUrl())); resources.add("initHostProvider"); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java index 07b8570fd..16c67469a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java @@ -35,6 +35,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class C3P0PooledConnectionProvider implements PooledConnectionProvider, CanReleaseResources { @@ -55,7 +56,7 @@ public class C3P0PooledConnectionProvider implements PooledConnectionProvider, C protected static final long poolExpirationCheckNanos = TimeUnit.MINUTES.toNanos(30); @Override - public boolean acceptsUrl(@NonNull String protocol, @NonNull HostSpec hostSpec, @NonNull Properties props) { + public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { return true; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java index d2d72b05c..a7b71e170 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Properties; import java.util.Set; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * Interface for connection plugins. This class implements ways to execute a JDBC method and to clean up resources used @@ -136,9 +137,7 @@ HostSpec getHostSpecByStrategy(final List hosts, final HostRole role, throws SQLException, UnsupportedOperationException; void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 4a4c7b20e..cba4007fd 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -52,6 +52,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; @@ -553,10 +554,7 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin } public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, - final HostListProviderService hostListProviderService) + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService) throws SQLException { TelemetryContext context = this.telemetryFactory.openTelemetryContext( "initHostProvider", TelemetryTraceLevel.NESTED); @@ -566,8 +564,7 @@ public void initHostProvider( JdbcMethod.INITHOSTPROVIDER, (PluginPipeline) (plugin, func) -> { - plugin.initHostProvider( - driverProtocol, initialUrl, props, hostListProviderService, func); + plugin.initHostProvider(connectionContext, hostListProviderService, func); return null; }, () -> { diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java index 8fddcf269..ebb475e7d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java @@ -22,8 +22,10 @@ import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.units.qual.N; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * Implement this interface in order to handle the physical connection creation process. @@ -34,14 +36,12 @@ public interface ConnectionProvider { * properties. Some ConnectionProvider implementations may not be able to handle certain URL * types or properties. * - * @param protocol the connection protocol (example "jdbc:mysql://") + * @param connectionContext the connection info for the original connection. * @param hostSpec the HostSpec containing the host-port information for the host to connect to - * @param props the Properties to use for the connection * @return true if this ConnectionProvider can provide connections for the given URL, otherwise * return false */ - boolean acceptsUrl( - @NonNull String protocol, @NonNull HostSpec hostSpec, @NonNull Properties props); + boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec); /** * Indicates whether the selection strategy is supported by the connection provider. diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java index a8d1e6100..92e5259d5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java @@ -23,6 +23,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; +import software.amazon.jdbc.util.connection.ConnectionContext; public class ConnectionProviderManager { @@ -32,11 +33,10 @@ public class ConnectionProviderManager { /** * {@link ConnectionProviderManager} constructor. * - * @param defaultProvider the default {@link ConnectionProvider} to use if a non-default - * ConnectionProvider has not been set or the non-default - * ConnectionProvider has been set but does not accept a requested URL + * @param defaultProvider the default {@link ConnectionProvider} to use if a non-default + * ConnectionProvider has not been set or the non-default + * ConnectionProvider has been set but does not accept a requested URL * @param effectiveConnProvider the non-default {@link ConnectionProvider} to use - * */ public ConnectionProviderManager( final ConnectionProvider defaultProvider, @@ -66,21 +66,19 @@ public static void setConnectionProvider(ConnectionProvider connProvider) { * non-default ConnectionProvider will be returned. Otherwise, the default ConnectionProvider will * be returned. See {@link ConnectionProvider#acceptsUrl} for more info. * - * @param driverProtocol the driver protocol that will be used to establish the connection - * @param host the host info for the connection that will be established - * @param props the connection properties for the connection that will be established + * @param connectionContext the connection info for the original connection. + * @param host the host info for the connection that will be established * @return the {@link ConnectionProvider} to use to establish a connection using the given driver * protocol, host details, and properties */ - public ConnectionProvider getConnectionProvider( - String driverProtocol, HostSpec host, Properties props) { + public ConnectionProvider getConnectionProvider(ConnectionContext connectionContext, HostSpec host) { final ConnectionProvider customConnectionProvider = Driver.getCustomConnectionProvider(); - if (customConnectionProvider != null && customConnectionProvider.acceptsUrl(driverProtocol, host, props)) { + if (customConnectionProvider != null && customConnectionProvider.acceptsUrl(connectionContext, host)) { return customConnectionProvider; } - if (this.effectiveConnProvider != null && this.effectiveConnProvider.acceptsUrl(driverProtocol, host, props)) { + if (this.effectiveConnProvider != null && this.effectiveConnProvider.acceptsUrl(connectionContext, host)) { return this.effectiveConnProvider; } @@ -190,7 +188,6 @@ public static void releaseResources() { * for every brand-new database connection. * * @param func A function that initialize a new connection - * * @deprecated @see Driver#setConnectionInitFunc(ConnectionInitFunc) */ @Deprecated diff --git a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java index 0231739d0..e465722f7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java @@ -38,6 +38,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * This class is a basic implementation of {@link ConnectionProvider} interface. It creates and @@ -67,20 +68,8 @@ public DataSourceConnectionProvider(final @NonNull DataSource dataSource) { this.dataSourceClassName = dataSource.getClass().getName(); } - /** - * Indicates whether this ConnectionProvider can provide connections for the given host and - * properties. Some ConnectionProvider implementations may not be able to handle certain URL - * types or properties. - * - * @param protocol The connection protocol (example "jdbc:mysql://") - * @param hostSpec The HostSpec containing the host-port information for the host to connect to - * @param props The Properties to use for the connection - * @return true if this ConnectionProvider can provide connections for the given URL, otherwise - * return false - */ @Override - public boolean acceptsUrl( - @NonNull String protocol, @NonNull HostSpec hostSpec, @NonNull Properties props) { + public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { return true; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java index 135eabc4b..e3dc6735b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java @@ -35,6 +35,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * This class is a basic implementation of {@link ConnectionProvider} interface. It creates and @@ -63,20 +64,8 @@ public DriverConnectionProvider(final java.sql.Driver driver) { this.targetDriverClassName = driver.getClass().getName(); } - /** - * Indicates whether this ConnectionProvider can provide connections for the given host and - * properties. Some ConnectionProvider implementations may not be able to handle certain URL - * types or properties. - * - * @param protocol The connection protocol (example "jdbc:mysql://") - * @param hostSpec The HostSpec containing the host-port information for the host to connect to - * @param props The Properties to use for the connection - * @return true if this ConnectionProvider can provide connections for the given URL, otherwise - * return false - */ @Override - public boolean acceptsUrl( - @NonNull String protocol, @NonNull HostSpec hostSpec, @NonNull Properties props) { + public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { return true; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java index 47e1968cc..31de59c9a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java @@ -43,6 +43,7 @@ import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class HikariPooledConnectionProvider implements PooledConnectionProvider, @@ -204,10 +205,9 @@ public HikariPooledConnectionProvider( @Override - public boolean acceptsUrl( - @NonNull String protocol, @NonNull HostSpec hostSpec, @NonNull Properties props) { + public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { if (this.acceptsUrlFunc != null) { - return this.acceptsUrlFunc.acceptsUrl(hostSpec, props); + return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionContext.getProps()); } final RdsUrlType urlType = rdsUtils.identifyRdsType(hostSpec.getHost()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index 6c9180650..e1af0469e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -29,6 +29,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; /** @@ -81,6 +82,12 @@ EnumSet setCurrentConnection( HostSpec getInitialConnectionHostSpec(); + /** + * Get the {@link ConnectionContext} for the current original connection. + * @return the {@link ConnectionContext} for the current original connection. + */ + ConnectionContext getConnectionContext(); + String getOriginalUrl(); /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index f6d266564..f00a75777 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -53,6 +53,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.CacheMap; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -69,20 +70,16 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, protected static final long DEFAULT_STATUS_CACHE_EXPIRE_NANO = TimeUnit.MINUTES.toNanos(60); protected final ConnectionPluginManager pluginManager; - private final Properties props; - private final String originalUrl; - private final String driverProtocol; + protected final ConnectionContext connectionContext; protected volatile HostListProvider hostListProvider; protected List allHosts = new ArrayList<>(); protected Connection currentConnection; protected HostSpec currentHostSpec; protected HostSpec initialConnectionHostSpec; - private boolean isInTransaction; - private final ExceptionManager exceptionManager; + protected boolean isInTransaction; + protected final ExceptionManager exceptionManager; protected final @Nullable ExceptionHandler exceptionHandler; protected final DialectProvider dialectProvider; - protected Dialect dialect; - protected TargetDriverDialect targetDriverDialect; protected @Nullable final ConfigurationProfile configurationProfile; protected final ConnectionProviderManager connectionProviderManager; @@ -91,40 +88,20 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, protected final ReentrantLock connectionSwitchLock = new ReentrantLock(); public PluginServiceImpl( - @NonNull final FullServicesContainer servicesContainer, - @NonNull final Properties props, - @NonNull final String originalUrl, - @NonNull final String targetDriverProtocol, - @NonNull final TargetDriverDialect targetDriverDialect) + @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionContext connectionContext) throws SQLException { - - this( - servicesContainer, - new ExceptionManager(), - props, - originalUrl, - targetDriverProtocol, - null, - targetDriverDialect, - null, - null); + this(servicesContainer, new ExceptionManager(), connectionContext, null, null, null); } public PluginServiceImpl( @NonNull final FullServicesContainer servicesContainer, - @NonNull final Properties props, - @NonNull final String originalUrl, - @NonNull final String targetDriverProtocol, - @NonNull final TargetDriverDialect targetDriverDialect, + @NonNull final ConnectionContext connectionContext, @Nullable final ConfigurationProfile configurationProfile) throws SQLException { this( servicesContainer, new ExceptionManager(), - props, - originalUrl, - targetDriverProtocol, + connectionContext, null, - targetDriverDialect, configurationProfile, null); } @@ -132,37 +109,32 @@ public PluginServiceImpl( public PluginServiceImpl( @NonNull final FullServicesContainer servicesContainer, @NonNull final ExceptionManager exceptionManager, - @NonNull final Properties props, - @NonNull final String originalUrl, - @NonNull final String targetDriverProtocol, + @NonNull final ConnectionContext connectionContext, @Nullable final DialectProvider dialectProvider, - @NonNull final TargetDriverDialect targetDriverDialect, @Nullable final ConfigurationProfile configurationProfile, @Nullable final SessionStateService sessionStateService) throws SQLException { this.servicesContainer = servicesContainer; this.pluginManager = servicesContainer.getConnectionPluginManager(); - this.props = props; - this.originalUrl = originalUrl; - this.driverProtocol = targetDriverProtocol; + this.connectionContext = connectionContext; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; this.dialectProvider = dialectProvider != null ? dialectProvider : new DialectManager(this); - this.targetDriverDialect = targetDriverDialect; this.connectionProviderManager = new ConnectionProviderManager( this.pluginManager.getDefaultConnProvider(), this.pluginManager.getEffectiveConnProvider()); this.sessionStateService = sessionStateService != null ? sessionStateService - : new SessionStateServiceImpl(this, this.props); + : new SessionStateServiceImpl(this, this.connectionContext.getProps()); this.exceptionHandler = this.configurationProfile != null && this.configurationProfile.getExceptionHandler() != null ? this.configurationProfile.getExceptionHandler() : null; - this.dialect = this.configurationProfile != null && this.configurationProfile.getDialect() != null + Dialect dialect = this.configurationProfile != null && this.configurationProfile.getDialect() != null ? this.configurationProfile.getDialect() - : this.dialectProvider.getDialect(this.driverProtocol, this.originalUrl, this.props); + : this.dialectProvider.getDialect(this.connectionContext); + this.connectionContext.setDbDialect(dialect); } @Override @@ -212,9 +184,14 @@ public HostSpec getInitialConnectionHostSpec() { return this.initialConnectionHostSpec; } + @Override + public ConnectionContext getConnectionContext() { + return this.connectionContext; + } + @Override public String getOriginalUrl() { - return this.originalUrl; + return this.connectionContext.getUrl(); } @Override @@ -256,13 +233,13 @@ public ConnectionProvider getDefaultConnectionProvider() { public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = - this.connectionProviderManager.getConnectionProvider(this.driverProtocol, host, props); + this.connectionProviderManager.getConnectionProvider(this.connectionContext, host); return (connectionProvider instanceof PooledConnectionProvider); } @Override public String getDriverProtocol() { - return this.driverProtocol; + return this.getConnectionContext().getProtocol(); } @Override @@ -313,7 +290,8 @@ public EnumSet setCurrentConnection( this.sessionStateService.applyCurrentSessionState(connection); this.setInTransaction(false); - if (isInTransaction && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.props)) { + if (isInTransaction + && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionContext.getProps())) { try { oldConnection.rollback(); } catch (final SQLException e) { @@ -614,7 +592,7 @@ public Connection connect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.connect( - this.driverProtocol, hostSpec, props, this.currentConnection == null, pluginToSkip); + this.connectionContext.getProtocol(), hostSpec, props, this.currentConnection == null, pluginToSkip); } @Override @@ -632,7 +610,7 @@ public Connection forceConnect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.forceConnect( - this.driverProtocol, hostSpec, props, this.currentConnection == null, pluginToSkip); + this.connectionContext.getProtocol(), hostSpec, props, this.currentConnection == null, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -665,7 +643,7 @@ public void releaseResources() { @Override @Deprecated public boolean isNetworkException(Throwable throwable) { - return this.isNetworkException(throwable, this.targetDriverDialect); + return this.isNetworkException(throwable, this.connectionContext.getDriverDialect()); } @Override @@ -673,7 +651,7 @@ public boolean isNetworkException(final Throwable throwable, @Nullable TargetDri if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(throwable, targetDriverDialect); } - return this.exceptionManager.isNetworkException(this.dialect, throwable, targetDriverDialect); + return this.exceptionManager.isNetworkException(this.connectionContext.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -681,13 +659,13 @@ public boolean isNetworkException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(sqlState); } - return this.exceptionManager.isNetworkException(this.dialect, sqlState); + return this.exceptionManager.isNetworkException(this.connectionContext.getDbDialect(), sqlState); } @Override @Deprecated public boolean isLoginException(Throwable throwable) { - return this.isLoginException(throwable, this.targetDriverDialect); + return this.isLoginException(throwable, this.connectionContext.getDriverDialect()); } @Override @@ -695,7 +673,7 @@ public boolean isLoginException(final Throwable throwable, @Nullable TargetDrive if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(throwable, targetDriverDialect); } - return this.exceptionManager.isLoginException(this.dialect, throwable, targetDriverDialect); + return this.exceptionManager.isLoginException(this.connectionContext.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -703,31 +681,30 @@ public boolean isLoginException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(sqlState); } - return this.exceptionManager.isLoginException(this.dialect, sqlState); + return this.exceptionManager.isLoginException(this.connectionContext.getDbDialect(), sqlState); } @Override public Dialect getDialect() { - return this.dialect; + return this.connectionContext.getDbDialect(); } @Override public TargetDriverDialect getTargetDriverDialect() { - return this.targetDriverDialect; + return this.connectionContext.getDriverDialect(); } public void updateDialect(final @NonNull Connection connection) throws SQLException { - final Dialect originalDialect = this.dialect; - this.dialect = this.dialectProvider.getDialect( - this.originalUrl, - this.initialConnectionHostSpec, - connection); - if (originalDialect == this.dialect) { + final Dialect originalDialect = this.connectionContext.getDbDialect(); + Dialect dialect = this.dialectProvider.getDialect( + this.connectionContext.getProtocol(), this.initialConnectionHostSpec, connection); + if (originalDialect == this.connectionContext.getDbDialect()) { return; } - final HostListProviderSupplier supplier = this.dialect.getHostListProvider(); - this.setHostListProvider(supplier.getProvider(this.props, this.originalUrl, this.servicesContainer)); + this.connectionContext.setDbDialect(dialect); + final HostListProviderSupplier supplier = this.connectionContext.getDbDialect().getHostListProvider(); + this.setHostListProvider(supplier.getProvider(this.connectionContext, this.servicesContainer)); this.refreshHostList(connection); } @@ -770,12 +747,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.props)); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getProps())); } @Override public Properties getProperties() { - return this.props; + return this.connectionContext.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index d29a1b3dd..aa24188fe 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -21,7 +21,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; @@ -36,6 +35,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.CacheMap; public class DialectManager implements DialectProvider { @@ -119,12 +119,7 @@ public static void resetEndpointCache() { } @Override - public Dialect getDialect( - final @NonNull String driverProtocol, - final @NonNull String url, - final @NonNull Properties props) - throws SQLException { - + public Dialect getDialect(final @NonNull ConnectionContext connectionContext) throws SQLException { this.canUpdate = false; this.dialect = null; @@ -136,10 +131,10 @@ public Dialect getDialect( return this.dialect; } - final String userDialectSetting = DIALECT.getString(props); + final String userDialectSetting = DIALECT.getString(connectionContext.getProps()); final String dialectCode = !StringUtils.isNullOrEmpty(userDialectSetting) ? userDialectSetting - : knownEndpointDialects.get(url); + : knownEndpointDialects.get(connectionContext.getUrl()); if (!StringUtils.isNullOrEmpty(dialectCode)) { final Dialect userDialect = knownDialectsByCode.get(dialectCode); @@ -154,18 +149,18 @@ public Dialect getDialect( } } - if (StringUtils.isNullOrEmpty(driverProtocol)) { + if (StringUtils.isNullOrEmpty(connectionContext.getProtocol())) { throw new IllegalArgumentException("protocol"); } - String host = url; + String host = connectionContext.getUrl(); final List hosts = this.connectionUrlParser.getHostsFromConnectionUrl( - url, true, pluginService::getHostSpecBuilder); + connectionContext.getUrl(), true, pluginService::getHostSpecBuilder); if (!Utils.isNullOrEmpty(hosts)) { host = hosts.get(0).getHost(); } - if (driverProtocol.contains("mysql")) { + if (connectionContext.getProtocol().contains("mysql")) { RdsUrlType type = this.rdsHelper.identifyRdsType(host); if (type.isRdsCluster()) { this.canUpdate = true; @@ -187,7 +182,7 @@ public Dialect getDialect( return this.dialect; } - if (driverProtocol.contains("postgresql")) { + if (connectionContext.getProtocol().contains("postgresql")) { RdsUrlType type = this.rdsHelper.identifyRdsType(host); if (RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP.equals(type)) { this.canUpdate = false; @@ -215,7 +210,7 @@ public Dialect getDialect( return this.dialect; } - if (driverProtocol.contains("mariadb")) { + if (connectionContext.getProtocol().contains("mariadb")) { this.canUpdate = true; this.dialectCode = DialectCodes.MARIADB; this.dialect = knownDialectsByCode.get(DialectCodes.MARIADB); diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java index bc90e7fdc..ed0f4cae7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java @@ -21,12 +21,10 @@ import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.util.connection.ConnectionContext; public interface DialectProvider { - Dialect getDialect( - final @NonNull String driverProtocol, - final @NonNull String url, - final @NonNull Properties props) throws SQLException; + Dialect getDialect(final @NonNull ConnectionContext connectionContext) throws SQLException; Dialect getDialect( final @NonNull String originalUrl, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java index 0dfe44dc5..4515d6285 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java @@ -16,15 +16,14 @@ package software.amazon.jdbc.dialect; -import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostListProvider; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; @FunctionalInterface public interface HostListProviderSupplier { @NonNull HostListProvider getProvider( - final @NonNull Properties properties, - final String initialUrl, + final @NonNull ConnectionContext connectionContext, final @NonNull FullServicesContainer servicesContainer); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java index de9f181d3..26043aa41 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java @@ -105,8 +105,8 @@ public List getDialectUpdateCandidates() { } public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> - new ConnectionStringHostListProvider(properties, initialUrl, servicesContainer.getHostListProviderService()); + return (connectionContext, servicesContainer) -> + new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index 80f55bdad..f088ef864 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -25,22 +25,23 @@ import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.util.ConnectionUrlParser; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.connection.ConnectionContext; public class ConnectionStringHostListProvider implements StaticHostListProvider { private static final Logger LOGGER = Logger.getLogger(ConnectionStringHostListProvider.class.getName()); final List hostList = new ArrayList<>(); - Properties properties; private boolean isInitialized = false; private final boolean isSingleWriterConnectionString; private final ConnectionUrlParser connectionUrlParser; - private final String initialUrl; + private final ConnectionContext connectionContext; private final HostListProviderService hostListProviderService; public static final AwsWrapperProperty SINGLE_WRITER_CONNECTION_STRING = @@ -51,20 +52,17 @@ public class ConnectionStringHostListProvider implements StaticHostListProvider + "cluster has only one writer. The writer must be the first host in the connection string"); public ConnectionStringHostListProvider( - final @NonNull Properties properties, - final String initialUrl, + final @NonNull ConnectionContext connectionContext, final @NonNull HostListProviderService hostListProviderService) { - this(properties, initialUrl, hostListProviderService, new ConnectionUrlParser()); + this(connectionContext, hostListProviderService, new ConnectionUrlParser()); } ConnectionStringHostListProvider( - final @NonNull Properties properties, - final String initialUrl, + final @NonNull ConnectionContext connectionContext, final @NonNull HostListProviderService hostListProviderService, final @NonNull ConnectionUrlParser connectionUrlParser) { - - this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(properties); - this.initialUrl = initialUrl; + this.connectionContext = connectionContext; + this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionContext.getProps()); this.connectionUrlParser = connectionUrlParser; this.hostListProviderService = hostListProviderService; } @@ -74,11 +72,13 @@ private void init() throws SQLException { return; } this.hostList.addAll( - this.connectionUrlParser.getHostsFromConnectionUrl(this.initialUrl, this.isSingleWriterConnectionString, + this.connectionUrlParser.getHostsFromConnectionUrl( + this.connectionContext.getUrl(), + this.isSingleWriterConnectionString, () -> this.hostListProviderService.getHostSpecBuilder())); if (this.hostList.isEmpty()) { throw new SQLException(Messages.get("ConnectionStringHostListProvider.parsedListEmpty", - new Object[] {this.initialUrl})); + new Object[] {this.connectionContext.getUrl()})); } this.hostListProviderService.setInitialConnectionHostSpec(this.hostList.get(0)); this.isInitialized = true; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java new file mode 100644 index 000000000..20f1c6669 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util.connection; + +import java.util.Properties; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.ConnectionUrlParser; +import software.amazon.jdbc.util.PropertyUtils; + +public class ConnectionContext { + protected final static ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); + protected final String url; + protected final String protocol; + protected final TargetDriverDialect driverDialect; + protected final Properties props; + protected Dialect dbDialect; + + public ConnectionContext(String url, TargetDriverDialect driverDialect, Properties props) { + this.url = url; + this.protocol = connectionUrlParser.getProtocol(url); + this.driverDialect = driverDialect; + this.props = props; + } + + public String getUrl() { + return url; + } + + public String getProtocol() { + return protocol; + } + + public TargetDriverDialect getDriverDialect() { + return driverDialect; + } + + public Properties getProps() { + return PropertyUtils.copyProperties(props); + } + + public Dialect getDbDialect() { + return dbDialect; + } + + public void setDbDialect(Dialect dbDialect) { + this.dbDialect = dbDialect; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index 90b47ac88..714e2700a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -58,6 +58,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -66,20 +67,17 @@ public class ConnectionWrapper implements Connection, CanReleaseResources { private static final Logger LOGGER = Logger.getLogger(ConnectionWrapper.class.getName()); + protected ConnectionContext connectionContext; protected ConnectionPluginManager pluginManager; protected TelemetryFactory telemetryFactory; protected PluginService pluginService; protected HostListProviderService hostListProviderService; protected PluginManagerService pluginManagerService; - protected String targetDriverProtocol; - protected String originalUrl; protected @Nullable ConfigurationProfile configurationProfile; protected @Nullable Throwable openConnectionStacktrace; - protected final ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); - public ConnectionWrapper( @NonNull final FullServicesContainer servicesContainer, @NonNull final Properties props, @@ -94,8 +92,7 @@ public ConnectionWrapper( throw new IllegalArgumentException("url"); } - this.originalUrl = url; - this.targetDriverProtocol = connectionUrlParser.getProtocol(url); + this.connectionContext = new ConnectionContext(url, driverDialect, props); this.configurationProfile = configurationProfile; final ConnectionPluginManager pluginManager = @@ -105,18 +102,14 @@ public ConnectionWrapper( this, servicesContainer.getTelemetryFactory()); servicesContainer.setConnectionPluginManager(pluginManager); - final PluginServiceImpl pluginService = new PluginServiceImpl( - servicesContainer, - props, - url, - this.targetDriverProtocol, - driverDialect, - this.configurationProfile); + + final PluginServiceImpl pluginService = + new PluginServiceImpl(servicesContainer, this.connectionContext, this.configurationProfile); servicesContainer.setHostListProviderService(pluginService); servicesContainer.setPluginService(pluginService); servicesContainer.setPluginManagerService(pluginService); - init(props, servicesContainer, defaultConnectionProvider, driverDialect); + init(props, servicesContainer); if (PropertyDefinition.LOG_UNCLOSED_CONNECTIONS.getBoolean(props)) { this.openConnectionStacktrace = new Throwable(Messages.get("ConnectionWrapper.unclosedConnectionInstantiated")); @@ -153,13 +146,10 @@ protected ConnectionWrapper( pluginManagerService ); - init(props, servicesContainer, defaultConnectionProvider, driverDialect); + init(props, servicesContainer); } - protected void init(final Properties props, - final FullServicesContainer servicesContainer, - final ConnectionProvider defaultConnectionProvider, - final TargetDriverDialect driverDialect) throws SQLException { + protected void init(final Properties props, final FullServicesContainer servicesContainer) throws SQLException { this.pluginManager = servicesContainer.getConnectionPluginManager(); this.telemetryFactory = servicesContainer.getTelemetryFactory(); this.pluginService = servicesContainer.getPluginService(); @@ -170,19 +160,16 @@ protected void init(final Properties props, final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); if (supplier != null) { - final HostListProvider provider = supplier.getProvider(props, this.originalUrl, servicesContainer); + final HostListProvider provider = supplier.getProvider(this.connectionContext, servicesContainer); hostListProviderService.setHostListProvider(provider); } - this.pluginManager.initHostProvider( - this.targetDriverProtocol, this.originalUrl, props, this.hostListProviderService); - + this.pluginManager.initHostProvider(this.connectionContext, this.hostListProviderService); this.pluginService.refreshHostList(); - if (this.pluginService.getCurrentConnection() == null) { final Connection conn = this.pluginManager.connect( - this.targetDriverProtocol, + this.connectionContext.getProtocol(), this.pluginService.getInitialConnectionHostSpec(), props, true, diff --git a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java index 6e5844ccf..05f700394 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java @@ -1,247 +1,247 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.zaxxer.hikari.HikariConfig; -import com.zaxxer.hikari.HikariDataSource; -import com.zaxxer.hikari.HikariPoolMXBean; -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.targetdriverdialect.ConnectInfo; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.Pair; -import software.amazon.jdbc.util.storage.SlidingExpirationCache; - -class HikariPooledConnectionProviderTest { - @Mock Connection mockConnection; - @Mock HikariDataSource mockDataSource; - @Mock HostSpec mockHostSpec; - @Mock HikariConfig mockConfig; - @Mock Dialect mockDialect; - @Mock TargetDriverDialect mockTargetDriverDialect; - @Mock HikariDataSource dsWithNoConnections; - @Mock HikariDataSource dsWith1Connection; - @Mock HikariDataSource dsWith2Connections; - @Mock HikariPoolMXBean mxBeanWithNoConnections; - @Mock HikariPoolMXBean mxBeanWith1Connection; - @Mock HikariPoolMXBean mxBeanWith2Connections; - private static final String LEAST_CONNECTIONS = "leastConnections"; - private final int port = 5432; - private final String user1 = "user1"; - private final String user2 = "user2"; - private final String password = "password"; - private final String db = "mydb"; - private final String writerUrlNoConnections = "writerWithNoConnections.XYZ.us-east-1.rds.amazonaws.com"; - private final HostSpec writerHostNoConnections = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host(writerUrlNoConnections).port(port).role(HostRole.WRITER).build(); - private final String readerUrl1Connection = "readerWith1connection.XYZ.us-east-1.rds.amazonaws.com"; - private final HostSpec readerHost1Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host(readerUrl1Connection).port(port).role(HostRole.READER).build(); - private final String readerUrl2Connection = "readerWith2connection.XYZ.us-east-1.rds.amazonaws.com"; - private final HostSpec readerHost2Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host(readerUrl2Connection).port(port).role(HostRole.READER).build(); - private final String protocol = "protocol://"; - - private final Properties defaultProps = getDefaultProps(); - private final List testHosts = getTestHosts(); - private HikariPooledConnectionProvider provider; - - private AutoCloseable closeable; - - private List getTestHosts() { - List hosts = new ArrayList<>(); - hosts.add(writerHostNoConnections); - hosts.add(readerHost1Connection); - hosts.add(readerHost2Connection); - return hosts; - } - - private Properties getDefaultProps() { - Properties props = new Properties(); - props.setProperty(PropertyDefinition.USER.name, user1); - props.setProperty(PropertyDefinition.PASSWORD.name, password); - props.setProperty(PropertyDefinition.DATABASE.name, db); - return props; - } - - @BeforeEach - void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - when(mockDataSource.getConnection()).thenReturn(mockConnection); - when(mockConnection.isValid(any(Integer.class))).thenReturn(true); - when(dsWithNoConnections.getHikariPoolMXBean()).thenReturn(mxBeanWithNoConnections); - when(mxBeanWithNoConnections.getActiveConnections()).thenReturn(0); - when(dsWith1Connection.getHikariPoolMXBean()).thenReturn(mxBeanWith1Connection); - when(mxBeanWith1Connection.getActiveConnections()).thenReturn(1); - when(dsWith2Connections.getHikariPoolMXBean()).thenReturn(mxBeanWith2Connections); - when(mxBeanWith2Connections.getActiveConnections()).thenReturn(2); - } - - @AfterEach - void tearDown() throws Exception { - if (provider != null) { - provider.releaseResources(); - } - closeable.close(); - } - - @Test - void testConnectWithDefaultMapping() throws SQLException { - when(mockHostSpec.getUrl()).thenReturn("url"); - final Set expectedUrls = new HashSet<>(Collections.singletonList("url")); - final Set expectedKeys = new HashSet<>( - Collections.singletonList(Pair.create("url", user1))); - - provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); - - doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); - doReturn(new ConnectInfo("url", new Properties())) - .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); - - Properties props = new Properties(); - props.setProperty(PropertyDefinition.USER.name, user1); - props.setProperty(PropertyDefinition.PASSWORD.name, password); - try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { - assertEquals(mockConnection, conn); - assertEquals(1, provider.getHostCount()); - final Set hosts = provider.getHosts(); - assertEquals(expectedUrls, hosts); - final Set keys = provider.getKeys(); - assertEquals(expectedKeys, keys); - } - } - - @Test - void testConnectWithCustomMapping() throws SQLException { - when(mockHostSpec.getUrl()).thenReturn("url"); - final Set expectedKeys = new HashSet<>( - Collections.singletonList(Pair.create("url", "url+someUniqueKey"))); - - provider = spy(new HikariPooledConnectionProvider( - (hostSpec, properties) -> mockConfig, - (hostSpec, properties) -> hostSpec.getUrl() + "+someUniqueKey")); - - doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); - - Properties props = new Properties(); - props.setProperty(PropertyDefinition.USER.name, user1); - props.setProperty(PropertyDefinition.PASSWORD.name, password); - try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { - assertEquals(mockConnection, conn); - assertEquals(1, provider.getHostCount()); - final Set keys = provider.getKeys(); - assertEquals(expectedKeys, keys); - } - } - - @Test - public void testAcceptsUrl() { - final String clusterUrl = "my-database.cluster-XYZ.us-east-1.rds.amazonaws.com"; - provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); - - assertTrue( - provider.acceptsUrl(protocol, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build(), - defaultProps)); - assertFalse( - provider.acceptsUrl(protocol, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build(), defaultProps)); - } - - @Test - public void testRandomStrategy() throws SQLException { - provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); - provider.setDatabasePools(getTestPoolMap()); - - HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, "random", defaultProps); - assertTrue(readerUrl1Connection.equals(selectedHost.getHost()) - || readerUrl2Connection.equals(selectedHost.getHost())); - } - - @Test - public void testLeastConnectionsStrategy() throws SQLException { - provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); - provider.setDatabasePools(getTestPoolMap()); - - HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, LEAST_CONNECTIONS, defaultProps); - // Other reader has 2 connections - assertEquals(readerUrl1Connection, selectedHost.getHost()); - } - - private SlidingExpirationCache getTestPoolMap() { - SlidingExpirationCache map = new SlidingExpirationCache<>(); - map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user1), - (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); - map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user2), - (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); - map.computeIfAbsent(Pair.create(readerHost1Connection.getUrl(), user1), - (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); - return map; - } - - @Test - public void testConfigurePool() throws SQLException { - provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); - final String expectedJdbcUrl = - protocol + readerHost1Connection.getUrl() + db + "?database=" + db; - doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) - .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); - - provider.configurePool(mockConfig, protocol, readerHost1Connection, defaultProps, mockTargetDriverDialect); - verify(mockConfig).setJdbcUrl(expectedJdbcUrl); - verify(mockConfig).setUsername(user1); - verify(mockConfig).setPassword(password); - } - - @Test - public void testConnectToDeletedInstance() throws SQLException { - provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); - - doReturn(mockDataSource).when(provider) - .createHikariDataSource(eq(protocol), eq(readerHost1Connection), eq(defaultProps), eq(mockTargetDriverDialect)); - when(mockDataSource.getConnection()).thenThrow(SQLException.class); - - assertThrows(SQLException.class, - () -> provider.connect(protocol, mockDialect, mockTargetDriverDialect, readerHost1Connection, defaultProps)); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import com.zaxxer.hikari.HikariConfig; +// import com.zaxxer.hikari.HikariDataSource; +// import com.zaxxer.hikari.HikariPoolMXBean; +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Collections; +// import java.util.HashSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.Set; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.targetdriverdialect.ConnectInfo; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.Pair; +// import software.amazon.jdbc.util.storage.SlidingExpirationCache; +// +// class HikariPooledConnectionProviderTest { +// @Mock Connection mockConnection; +// @Mock HikariDataSource mockDataSource; +// @Mock HostSpec mockHostSpec; +// @Mock HikariConfig mockConfig; +// @Mock Dialect mockDialect; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// @Mock HikariDataSource dsWithNoConnections; +// @Mock HikariDataSource dsWith1Connection; +// @Mock HikariDataSource dsWith2Connections; +// @Mock HikariPoolMXBean mxBeanWithNoConnections; +// @Mock HikariPoolMXBean mxBeanWith1Connection; +// @Mock HikariPoolMXBean mxBeanWith2Connections; +// private static final String LEAST_CONNECTIONS = "leastConnections"; +// private final int port = 5432; +// private final String user1 = "user1"; +// private final String user2 = "user2"; +// private final String password = "password"; +// private final String db = "mydb"; +// private final String writerUrlNoConnections = "writerWithNoConnections.XYZ.us-east-1.rds.amazonaws.com"; +// private final HostSpec writerHostNoConnections = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host(writerUrlNoConnections).port(port).role(HostRole.WRITER).build(); +// private final String readerUrl1Connection = "readerWith1connection.XYZ.us-east-1.rds.amazonaws.com"; +// private final HostSpec readerHost1Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host(readerUrl1Connection).port(port).role(HostRole.READER).build(); +// private final String readerUrl2Connection = "readerWith2connection.XYZ.us-east-1.rds.amazonaws.com"; +// private final HostSpec readerHost2Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host(readerUrl2Connection).port(port).role(HostRole.READER).build(); +// private final String protocol = "protocol://"; +// +// private final Properties defaultProps = getDefaultProps(); +// private final List testHosts = getTestHosts(); +// private HikariPooledConnectionProvider provider; +// +// private AutoCloseable closeable; +// +// private List getTestHosts() { +// List hosts = new ArrayList<>(); +// hosts.add(writerHostNoConnections); +// hosts.add(readerHost1Connection); +// hosts.add(readerHost2Connection); +// return hosts; +// } +// +// private Properties getDefaultProps() { +// Properties props = new Properties(); +// props.setProperty(PropertyDefinition.USER.name, user1); +// props.setProperty(PropertyDefinition.PASSWORD.name, password); +// props.setProperty(PropertyDefinition.DATABASE.name, db); +// return props; +// } +// +// @BeforeEach +// void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockDataSource.getConnection()).thenReturn(mockConnection); +// when(mockConnection.isValid(any(Integer.class))).thenReturn(true); +// when(dsWithNoConnections.getHikariPoolMXBean()).thenReturn(mxBeanWithNoConnections); +// when(mxBeanWithNoConnections.getActiveConnections()).thenReturn(0); +// when(dsWith1Connection.getHikariPoolMXBean()).thenReturn(mxBeanWith1Connection); +// when(mxBeanWith1Connection.getActiveConnections()).thenReturn(1); +// when(dsWith2Connections.getHikariPoolMXBean()).thenReturn(mxBeanWith2Connections); +// when(mxBeanWith2Connections.getActiveConnections()).thenReturn(2); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// if (provider != null) { +// provider.releaseResources(); +// } +// closeable.close(); +// } +// +// @Test +// void testConnectWithDefaultMapping() throws SQLException { +// when(mockHostSpec.getUrl()).thenReturn("url"); +// final Set expectedUrls = new HashSet<>(Collections.singletonList("url")); +// final Set expectedKeys = new HashSet<>( +// Collections.singletonList(Pair.create("url", user1))); +// +// provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); +// +// doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); +// doReturn(new ConnectInfo("url", new Properties())) +// .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); +// +// Properties props = new Properties(); +// props.setProperty(PropertyDefinition.USER.name, user1); +// props.setProperty(PropertyDefinition.PASSWORD.name, password); +// try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { +// assertEquals(mockConnection, conn); +// assertEquals(1, provider.getHostCount()); +// final Set hosts = provider.getHosts(); +// assertEquals(expectedUrls, hosts); +// final Set keys = provider.getKeys(); +// assertEquals(expectedKeys, keys); +// } +// } +// +// @Test +// void testConnectWithCustomMapping() throws SQLException { +// when(mockHostSpec.getUrl()).thenReturn("url"); +// final Set expectedKeys = new HashSet<>( +// Collections.singletonList(Pair.create("url", "url+someUniqueKey"))); +// +// provider = spy(new HikariPooledConnectionProvider( +// (hostSpec, properties) -> mockConfig, +// (hostSpec, properties) -> hostSpec.getUrl() + "+someUniqueKey")); +// +// doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); +// +// Properties props = new Properties(); +// props.setProperty(PropertyDefinition.USER.name, user1); +// props.setProperty(PropertyDefinition.PASSWORD.name, password); +// try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { +// assertEquals(mockConnection, conn); +// assertEquals(1, provider.getHostCount()); +// final Set keys = provider.getKeys(); +// assertEquals(expectedKeys, keys); +// } +// } +// +// @Test +// public void testAcceptsUrl() { +// final String clusterUrl = "my-database.cluster-XYZ.us-east-1.rds.amazonaws.com"; +// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); +// +// assertTrue( +// provider.acceptsUrl(protocol, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build(), +// defaultProps)); +// assertFalse( +// provider.acceptsUrl(protocol, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build(), defaultProps)); +// } +// +// @Test +// public void testRandomStrategy() throws SQLException { +// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); +// provider.setDatabasePools(getTestPoolMap()); +// +// HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, "random", defaultProps); +// assertTrue(readerUrl1Connection.equals(selectedHost.getHost()) +// || readerUrl2Connection.equals(selectedHost.getHost())); +// } +// +// @Test +// public void testLeastConnectionsStrategy() throws SQLException { +// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); +// provider.setDatabasePools(getTestPoolMap()); +// +// HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, LEAST_CONNECTIONS, defaultProps); +// // Other reader has 2 connections +// assertEquals(readerUrl1Connection, selectedHost.getHost()); +// } +// +// private SlidingExpirationCache getTestPoolMap() { +// SlidingExpirationCache map = new SlidingExpirationCache<>(); +// map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user1), +// (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); +// map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user2), +// (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); +// map.computeIfAbsent(Pair.create(readerHost1Connection.getUrl(), user1), +// (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); +// return map; +// } +// +// @Test +// public void testConfigurePool() throws SQLException { +// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); +// final String expectedJdbcUrl = +// protocol + readerHost1Connection.getUrl() + db + "?database=" + db; +// doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) +// .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); +// +// provider.configurePool(mockConfig, protocol, readerHost1Connection, defaultProps, mockTargetDriverDialect); +// verify(mockConfig).setJdbcUrl(expectedJdbcUrl); +// verify(mockConfig).setUsername(user1); +// verify(mockConfig).setPassword(password); +// } +// +// @Test +// public void testConnectToDeletedInstance() throws SQLException { +// provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); +// +// doReturn(mockDataSource).when(provider) +// .createHikariDataSource(eq(protocol), eq(readerHost1Connection), eq(defaultProps), eq(mockTargetDriverDialect)); +// when(mockDataSource.getConnection()).thenThrow(SQLException.class); +// +// assertThrows(SQLException.class, +// () -> provider.connect(protocol, mockDialect, mockTargetDriverDialect, readerHost1Connection, defaultProps)); +// } +// } From be51cc70b71b4a49ecda45a78d18be252d07b970 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 17 Sep 2025 16:59:01 -0700 Subject: [PATCH 43/54] compiles with tests commented out --- .../ConnectionPluginManagerBenchmarks.java | 562 +++-- .../jdbc/benchmarks/PluginBenchmarks.java | 694 +++--- .../testplugin/BenchmarkPlugin.java | 16 +- docs/development-guide/LoadablePlugins.md | 18 +- .../jdbc/C3P0PooledConnectionProvider.java | 25 +- .../amazon/jdbc/ConnectionPlugin.java | 15 +- .../amazon/jdbc/ConnectionPluginManager.java | 22 +- .../amazon/jdbc/ConnectionProvider.java | 16 +- .../jdbc/ConnectionProviderManager.java | 12 +- .../jdbc/DataSourceConnectionProvider.java | 33 +- .../amazon/jdbc/DriverConnectionProvider.java | 31 +- .../jdbc/HikariPooledConnectionProvider.java | 64 +- .../amazon/jdbc/HikariPoolsHolder.java | 2 +- .../jdbc/LeastConnectionsHostSelector.java | 10 +- .../amazon/jdbc/PartialPluginService.java | 75 +- .../amazon/jdbc/PluginServiceImpl.java | 12 +- .../jdbc/dialect/AuroraMysqlDialect.java | 9 +- .../amazon/jdbc/dialect/AuroraPgDialect.java | 9 +- .../amazon/jdbc/dialect/DialectManager.java | 2 +- .../amazon/jdbc/dialect/MariaDbDialect.java | 4 +- .../amazon/jdbc/dialect/PgDialect.java | 4 +- .../RdsMultiAzDbClusterMysqlDialect.java | 8 +- .../dialect/RdsMultiAzDbClusterPgDialect.java | 8 +- .../amazon/jdbc/dialect/UnknownDialect.java | 4 +- .../AuroraHostListProvider.java | 9 +- .../ConnectionStringHostListProvider.java | 4 +- .../hostlistprovider/RdsHostListProvider.java | 17 +- .../RdsMultiAzDbClusterListProvider.java | 9 +- .../ClusterTopologyMonitorImpl.java | 6 +- .../MonitoringRdsHostListProvider.java | 19 +- .../MonitoringRdsMultiAzHostListProvider.java | 18 +- .../MultiAzClusterTopologyMonitorImpl.java | 2 - .../jdbc/plugin/AbstractConnectionPlugin.java | 12 +- .../plugin/AuroraConnectionTrackerPlugin.java | 8 +- ...AuroraInitialConnectionStrategyPlugin.java | 17 +- .../AwsSecretsManagerConnectionPlugin.java | 11 +- .../plugin/ConnectTimeConnectionPlugin.java | 17 +- .../jdbc/plugin/DefaultConnectionPlugin.java | 61 +- .../bluegreen/BlueGreenConnectionPlugin.java | 4 +- .../bluegreen/BlueGreenStatusMonitor.java | 18 +- .../customendpoint/CustomEndpointPlugin.java | 13 +- .../plugin/dev/DeveloperConnectionPlugin.java | 27 +- .../ExceptionSimulatorConnectCallback.java | 5 +- .../efm/HostMonitoringConnectionPlugin.java | 9 +- .../plugin/efm2/HostMonitorServiceImpl.java | 6 +- .../efm2/HostMonitoringConnectionPlugin.java | 9 +- .../ClusterAwareReaderFailoverHandler.java | 6 +- .../ClusterAwareWriterFailoverHandler.java | 6 +- .../failover/FailoverConnectionPlugin.java | 16 +- .../failover2/FailoverConnectionPlugin.java | 22 +- .../federatedauth/FederatedAuthPlugin.java | 18 +- .../plugin/federatedauth/OktaAuthPlugin.java | 19 +- .../plugin/iam/IamAuthConnectionPlugin.java | 19 +- .../limitless/LimitlessConnectionPlugin.java | 9 +- .../limitless/LimitlessRouterServiceImpl.java | 22 +- .../ReadWriteSplittingPlugin.java | 12 +- .../plugin/staledns/AuroraStaleDnsHelper.java | 7 +- .../plugin/staledns/AuroraStaleDnsPlugin.java | 15 +- .../FastestResponseStrategyPlugin.java | 8 +- .../HostResponseTimeServiceImpl.java | 6 +- .../amazon/jdbc/util/ServiceUtility.java | 21 +- .../util/connection/ConnectionContext.java | 10 +- .../connection/ConnectionServiceImpl.java | 28 +- .../jdbc/util/monitoring/MonitorService.java | 16 +- .../util/monitoring/MonitorServiceImpl.java | 29 +- .../jdbc/wrapper/ConnectionWrapper.java | 11 +- .../aurora/TestAuroraHostListProvider.java | 66 +- .../aurora/TestPluginServiceImpl.java | 12 +- .../tests/AdvancedPerformanceTest.java | 1534 ++++++------- .../jdbc/ConnectionPluginManagerTests.java | 1940 ++++++++--------- .../amazon/jdbc/DialectDetectionTests.java | 578 ++--- .../amazon/jdbc/PluginServiceImplTests.java | 1892 ++++++++-------- .../RdsHostListProviderTest.java | 1258 +++++------ .../RdsMultiAzDbClusterListProviderTest.java | 940 ++++---- .../amazon/jdbc/mock/TestPluginOne.java | 311 ++- .../amazon/jdbc/mock/TestPluginThree.java | 175 +- .../jdbc/mock/TestPluginThrowException.java | 225 +- .../amazon/jdbc/mock/TestPluginTwo.java | 66 +- .../AuroraConnectionTrackerPluginTest.java | 490 ++--- ...AwsSecretsManagerConnectionPluginTest.java | 1032 ++++----- .../plugin/DefaultConnectionPluginTest.java | 276 +-- .../CustomEndpointPluginTest.java | 318 +-- .../dev/DeveloperConnectionPluginTest.java | 712 +++--- .../HostMonitoringConnectionPluginTest.java | 688 +++--- .../FederatedAuthPluginTest.java | 450 ++-- .../federatedauth/OktaAuthPluginTest.java | 440 ++-- .../iam/IamAuthConnectionPluginTest.java | 590 ++--- .../LimitlessConnectionPluginTest.java | 324 +-- .../ReadWriteSplittingPluginTest.java | 1252 +++++------ .../monitoring/MonitorServiceImplTest.java | 630 +++--- 90 files changed, 9104 insertions(+), 9351 deletions(-) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java index e1078539a..773bd83ce 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java @@ -1,284 +1,278 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.benchmarks; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.when; -import static org.mockito.MockitoAnnotations.openMocks; - -import java.lang.annotation.Target; -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.mockito.Mock; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Level; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.OutputTimeUnit; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.TearDown; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.profile.GCProfiler; -import org.openjdk.jmh.runner.Runner; -import org.openjdk.jmh.runner.RunnerException; -import org.openjdk.jmh.runner.options.Options; -import org.openjdk.jmh.runner.options.OptionsBuilder; -import software.amazon.jdbc.ConnectionPluginFactory; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.NodeChangeOptions; -import software.amazon.jdbc.OldConnectionSuggestedAction; -import software.amazon.jdbc.PluginManagerService; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.benchmarks.testplugin.BenchmarkPluginFactory; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; -import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; -import software.amazon.jdbc.util.telemetry.GaugeCallable; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; -import software.amazon.jdbc.wrapper.ConnectionWrapper; - -@State(Scope.Benchmark) -@Fork(3) -@Warmup(iterations = 3) -@Measurement(iterations = 10) -@BenchmarkMode(Mode.SingleShotTime) -@OutputTimeUnit(TimeUnit.NANOSECONDS) -public class ConnectionPluginManagerBenchmarks { - - private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; - private static final String FIELD_SERVER_ID = "SERVER_ID"; - private static final String FIELD_SESSION_ID = "SESSION_ID"; - private static final String url = "protocol//url"; - private ConnectionContext pluginsContext; - private ConnectionContext noPluginsContext; - private ConnectionPluginManager pluginManager; - private ConnectionPluginManager pluginManagerWithNoPlugins; - - @Mock ConnectionProvider mockConnectionProvider; - @Mock ConnectionWrapper mockConnectionWrapper; - @Mock FullServicesContainer mockServicesContainer; - @Mock PluginService mockPluginService; - @Mock PluginManagerService mockPluginManagerService; - @Mock TargetDriverDialect mockDriverDialect; - @Mock TelemetryFactory mockTelemetryFactory; - @Mock HostListProviderService mockHostListProvider; - @Mock Connection mockConnection; - @Mock Statement mockStatement; - @Mock ResultSet mockResultSet; - @Mock TelemetryContext mockTelemetryContext; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryGauge mockTelemetryGauge; - ConfigurationProfile configurationProfile; - private AutoCloseable closeable; - - public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() - .include(software.amazon.jdbc.benchmarks.PluginBenchmarks.class.getSimpleName()) - .addProfiler(GCProfiler.class) - .detectJvmArgs() - .build(); - - new Runner(opt).run(); - } - - @Setup(Level.Iteration) - public void setUpIteration() throws Exception { - closeable = openMocks(this); - - when(mockConnectionProvider.connect( - anyString(), - any(Dialect.class), - any(TargetDriverDialect.class), - any(HostSpec.class), - any(Properties.class))).thenReturn(mockConnection); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - when(mockConnection.createStatement()).thenReturn(mockStatement); - when(mockConnection.createStatement()).thenReturn(mockStatement); - when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true, true, false); - when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); - when(mockResultSet.getString(eq(FIELD_SERVER_ID))) - .thenReturn("myInstance1.domain.com", "myInstance2.domain.com", "myInstance3.domain.com"); - when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - - // Create a plugin chain with 10 custom test plugins. - final List> pluginFactories = new ArrayList<>( - Collections.nCopies(10, BenchmarkPluginFactory.class)); - - configurationProfile = ConfigurationProfileBuilder.get() - .withName("benchmark") - .withPluginFactories(pluginFactories) - .build(); - - Properties noPluginsProps = new Properties(); - noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); - this.noPluginsContext = new ConnectionContext(url, mockDriverDialect, noPluginsProps); - - Properties pluginsProps = new Properties(); - pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); - pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); - this.pluginsContext = new ConnectionContext(url, mockDriverDialect, pluginsProps); - - TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); - - pluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, - mockConnectionWrapper, - telemetryFactory); - pluginManager.init(mockServicesContainer, pluginsProps, mockPluginManagerService, configurationProfile); - - pluginManagerWithNoPlugins = new ConnectionPluginManager(mockConnectionProvider, null, - mockConnectionWrapper, telemetryFactory); - pluginManagerWithNoPlugins.init(mockServicesContainer, noPluginsProps, mockPluginManagerService, null); - } - - @TearDown(Level.Iteration) - public void tearDownIteration() throws Exception { - closeable.close(); - } - - @Benchmark - public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws SQLException { - final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, - mockConnectionWrapper, mockTelemetryFactory); - manager.init(mockServicesContainer, this.noPluginsContext.getProps(), mockPluginManagerService, configurationProfile); - return manager; - } - - @Benchmark - public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws SQLException { - final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, - mockConnectionWrapper, mockTelemetryFactory); - manager.init(mockServicesContainer, this.pluginsContext.getProps(), mockPluginManagerService, configurationProfile); - return manager; - } - - @Benchmark - public Connection connectWithPlugins() throws SQLException { - return pluginManager.connect( - "driverProtocol", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), - this.pluginsContext.getProps(), - true, - null); - } - - @Benchmark - public Connection connectWithNoPlugins() throws SQLException { - return pluginManagerWithNoPlugins.connect( - "driverProtocol", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), - this.noPluginsContext.getProps(), - true, - null); - } - - @Benchmark - public Integer executeWithPlugins() { - return pluginManager.execute( - int.class, - RuntimeException.class, - mockStatement, - JdbcMethod.STATEMENT_EXECUTE, - () -> 1, - new Object[] {1} - ); - } - - @Benchmark - public Integer executeWithNoPlugins() { - return pluginManagerWithNoPlugins.execute( - int.class, - RuntimeException.class, - mockStatement, - JdbcMethod.STATEMENT_EXECUTE, - () -> 1, - new Object[] {1} - ); - } - - @Benchmark - public ConnectionPluginManager initHostProvidersWithPlugins() throws SQLException { - pluginManager.initHostProvider(this.pluginsContext, mockHostListProvider); - return pluginManager; - } - - @Benchmark - public ConnectionPluginManager initHostProvidersWithNoPlugins() throws SQLException { - pluginManagerWithNoPlugins.initHostProvider(this.noPluginsContext, mockHostListProvider); - return pluginManager; - } - - @Benchmark - public EnumSet notifyConnectionChangedWithPlugins() { - return pluginManager.notifyConnectionChanged( - EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), - null); - } - - @Benchmark - public EnumSet notifyConnectionChangedWithNoPlugins() { - return pluginManagerWithNoPlugins.notifyConnectionChanged( - EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), - null); - } - - @Benchmark - public ConnectionPluginManager releaseResourcesWithPlugins() { - pluginManager.releaseResources(); - return pluginManager; - } - - @Benchmark - public ConnectionPluginManager releaseResourcesWithNoPlugins() { - pluginManagerWithNoPlugins.releaseResources(); - return pluginManager; - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.benchmarks; +// +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.when; +// import static org.mockito.MockitoAnnotations.openMocks; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.ArrayList; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.mockito.Mock; +// import org.openjdk.jmh.annotations.Benchmark; +// import org.openjdk.jmh.annotations.BenchmarkMode; +// import org.openjdk.jmh.annotations.Fork; +// import org.openjdk.jmh.annotations.Level; +// import org.openjdk.jmh.annotations.Measurement; +// import org.openjdk.jmh.annotations.Mode; +// import org.openjdk.jmh.annotations.OutputTimeUnit; +// import org.openjdk.jmh.annotations.Scope; +// import org.openjdk.jmh.annotations.Setup; +// import org.openjdk.jmh.annotations.State; +// import org.openjdk.jmh.annotations.TearDown; +// import org.openjdk.jmh.annotations.Warmup; +// import org.openjdk.jmh.profile.GCProfiler; +// import org.openjdk.jmh.runner.Runner; +// import org.openjdk.jmh.runner.RunnerException; +// import org.openjdk.jmh.runner.options.Options; +// import org.openjdk.jmh.runner.options.OptionsBuilder; +// import software.amazon.jdbc.ConnectionPluginFactory; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.ConnectionProvider; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcMethod; +// import software.amazon.jdbc.NodeChangeOptions; +// import software.amazon.jdbc.OldConnectionSuggestedAction; +// import software.amazon.jdbc.PluginManagerService; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.benchmarks.testplugin.BenchmarkPluginFactory; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.profile.ConfigurationProfile; +// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.connection.ConnectionContext; +// import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; +// import software.amazon.jdbc.util.telemetry.GaugeCallable; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryGauge; +// import software.amazon.jdbc.wrapper.ConnectionWrapper; +// +// @State(Scope.Benchmark) +// @Fork(3) +// @Warmup(iterations = 3) +// @Measurement(iterations = 10) +// @BenchmarkMode(Mode.SingleShotTime) +// @OutputTimeUnit(TimeUnit.NANOSECONDS) +// public class ConnectionPluginManagerBenchmarks { +// +// private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; +// private static final String FIELD_SERVER_ID = "SERVER_ID"; +// private static final String FIELD_SESSION_ID = "SESSION_ID"; +// private static final String url = "protocol//url"; +// private ConnectionContext pluginsContext; +// private ConnectionContext noPluginsContext; +// private ConnectionPluginManager pluginManager; +// private ConnectionPluginManager pluginManagerWithNoPlugins; +// +// @Mock ConnectionProvider mockConnectionProvider; +// @Mock ConnectionWrapper mockConnectionWrapper; +// @Mock FullServicesContainer mockServicesContainer; +// @Mock PluginService mockPluginService; +// @Mock PluginManagerService mockPluginManagerService; +// @Mock TargetDriverDialect mockDriverDialect; +// @Mock TelemetryFactory mockTelemetryFactory; +// @Mock HostListProviderService mockHostListProvider; +// @Mock Connection mockConnection; +// @Mock Statement mockStatement; +// @Mock ResultSet mockResultSet; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryGauge mockTelemetryGauge; +// ConfigurationProfile configurationProfile; +// private AutoCloseable closeable; +// +// public static void main(String[] args) throws RunnerException { +// Options opt = new OptionsBuilder() +// .include(software.amazon.jdbc.benchmarks.PluginBenchmarks.class.getSimpleName()) +// .addProfiler(GCProfiler.class) +// .detectJvmArgs() +// .build(); +// +// new Runner(opt).run(); +// } +// +// @Setup(Level.Iteration) +// public void setUpIteration() throws Exception { +// closeable = openMocks(this); +// +// when(mockConnectionProvider.connect(any(), any(HostSpec.class))).thenReturn(mockConnection); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); +// when(mockConnection.createStatement()).thenReturn(mockStatement); +// when(mockConnection.createStatement()).thenReturn(mockStatement); +// when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); +// when(mockResultSet.next()).thenReturn(true, true, false); +// when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); +// when(mockResultSet.getString(eq(FIELD_SERVER_ID))) +// .thenReturn("myInstance1.domain.com", "myInstance2.domain.com", "myInstance3.domain.com"); +// when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// +// // Create a plugin chain with 10 custom test plugins. +// final List> pluginFactories = new ArrayList<>( +// Collections.nCopies(10, BenchmarkPluginFactory.class)); +// +// configurationProfile = ConfigurationProfileBuilder.get() +// .withName("benchmark") +// .withPluginFactories(pluginFactories) +// .build(); +// +// Properties noPluginsProps = new Properties(); +// noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); +// this.noPluginsContext = new ConnectionContext(url, mockDriverDialect, noPluginsProps); +// +// Properties pluginsProps = new Properties(); +// pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); +// pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); +// this.pluginsContext = new ConnectionContext(url, mockDriverDialect, pluginsProps); +// +// TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); +// +// pluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, +// mockConnectionWrapper, +// telemetryFactory); +// pluginManager.init(mockServicesContainer, pluginsProps, mockPluginManagerService, configurationProfile); +// +// pluginManagerWithNoPlugins = new ConnectionPluginManager(mockConnectionProvider, null, +// mockConnectionWrapper, telemetryFactory); +// pluginManagerWithNoPlugins.init(mockServicesContainer, noPluginsProps, mockPluginManagerService, null); +// } +// +// @TearDown(Level.Iteration) +// public void tearDownIteration() throws Exception { +// closeable.close(); +// } +// +// @Benchmark +// public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws SQLException { +// final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, +// mockConnectionWrapper, mockTelemetryFactory); +// manager.init(mockServicesContainer, this.noPluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); +// return manager; +// } +// +// @Benchmark +// public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws SQLException { +// final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, +// mockConnectionWrapper, mockTelemetryFactory); +// manager.init(mockServicesContainer, this.pluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); +// return manager; +// } +// +// @Benchmark +// public Connection connectWithPlugins() throws SQLException { +// return pluginManager.connect( +// "driverProtocol", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), +// this.pluginsContext.getPropsCopy(), +// true, +// null); +// } +// +// @Benchmark +// public Connection connectWithNoPlugins() throws SQLException { +// return pluginManagerWithNoPlugins.connect( +// "driverProtocol", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), +// this.noPluginsContext.getPropsCopy(), +// true, +// null); +// } +// +// @Benchmark +// public Integer executeWithPlugins() { +// return pluginManager.execute( +// int.class, +// RuntimeException.class, +// mockStatement, +// JdbcMethod.STATEMENT_EXECUTE, +// () -> 1, +// new Object[] {1} +// ); +// } +// +// @Benchmark +// public Integer executeWithNoPlugins() { +// return pluginManagerWithNoPlugins.execute( +// int.class, +// RuntimeException.class, +// mockStatement, +// JdbcMethod.STATEMENT_EXECUTE, +// () -> 1, +// new Object[] {1} +// ); +// } +// +// @Benchmark +// public ConnectionPluginManager initHostProvidersWithPlugins() throws SQLException { +// pluginManager.initHostProvider(this.pluginsContext, mockHostListProvider); +// return pluginManager; +// } +// +// @Benchmark +// public ConnectionPluginManager initHostProvidersWithNoPlugins() throws SQLException { +// pluginManagerWithNoPlugins.initHostProvider(this.noPluginsContext, mockHostListProvider); +// return pluginManager; +// } +// +// @Benchmark +// public EnumSet notifyConnectionChangedWithPlugins() { +// return pluginManager.notifyConnectionChanged( +// EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), +// null); +// } +// +// @Benchmark +// public EnumSet notifyConnectionChangedWithNoPlugins() { +// return pluginManagerWithNoPlugins.notifyConnectionChanged( +// EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), +// null); +// } +// +// @Benchmark +// public ConnectionPluginManager releaseResourcesWithPlugins() { +// pluginManager.releaseResources(); +// return pluginManager; +// } +// +// @Benchmark +// public ConnectionPluginManager releaseResourcesWithNoPlugins() { +// pluginManagerWithNoPlugins.releaseResources(); +// return pluginManager; +// } +// } diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java index 35932705c..36c31c61f 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java @@ -1,347 +1,347 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.benchmarks; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.when; - -import com.zaxxer.hikari.HikariConfig; -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Level; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.OutputTimeUnit; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.TearDown; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.profile.GCProfiler; -import org.openjdk.jmh.runner.Runner; -import org.openjdk.jmh.runner.RunnerException; -import org.openjdk.jmh.runner.options.Options; -import org.openjdk.jmh.runner.options.OptionsBuilder; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.ConnectionProviderManager; -import software.amazon.jdbc.Driver; -import software.amazon.jdbc.HikariPooledConnectionProvider; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.PluginManagerService; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.benchmarks.testplugin.TestConnectionWrapper; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.monitoring.MonitorService; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.telemetry.GaugeCallable; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; -import software.amazon.jdbc.wrapper.ConnectionWrapper; - -@State(Scope.Benchmark) -@Fork(3) -@Warmup(iterations = 3) -@Measurement(iterations = 10) -@BenchmarkMode(Mode.SingleShotTime) -@OutputTimeUnit(TimeUnit.NANOSECONDS) -public class PluginBenchmarks { - - private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; - private static final String FIELD_SERVER_ID = "SERVER_ID"; - private static final String FIELD_SESSION_ID = "SESSION_ID"; - private static final String CONNECTION_STRING = "jdbc:postgresql://my.domain.com"; - private static final String PG_CONNECTION_STRING = - "jdbc:aws-wrapper:postgresql://instance-0.XYZ.us-east-2.rds.amazonaws.com"; - private static final String TEST_HOST = "instance-0"; - private static final int TEST_PORT = 5432; - private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host(TEST_HOST).port(TEST_PORT).build(); - - @Mock private StorageService mockStorageService; - @Mock private MonitorService mockMonitorService; - @Mock private PluginService mockPluginService; - @Mock private TargetDriverDialect mockTargetDriverDialect; - @Mock private Dialect mockDialect; - @Mock private ConnectionPluginManager mockConnectionPluginManager; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryGauge mockTelemetryGauge; - @Mock private HostListProviderService mockHostListProviderService; - @Mock private PluginManagerService mockPluginManagerService; - @Mock ConnectionProvider mockConnectionProvider; - @Mock Connection mockConnection; - @Mock Statement mockStatement; - @Mock ResultSet mockResultSet; - private AutoCloseable closeable; - - public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() - .include(PluginBenchmarks.class.getSimpleName()) - .addProfiler(GCProfiler.class) - .detectJvmArgs() - .build(); - - new Runner(opt).run(); - } - - @Setup(Level.Iteration) - public void setUpIteration() throws Exception { - closeable = MockitoAnnotations.openMocks(this); - when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any())) - .thenReturn(mockConnection); - when(mockConnectionPluginManager.execute( - any(), any(), any(), eq(JdbcMethod.CONNECTION_CREATESTATEMENT), any(), any())) - .thenReturn(mockStatement); - when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - // noinspection unchecked - when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - when(mockConnectionProvider.connect( - anyString(), - any(Dialect.class), - any(TargetDriverDialect.class), - any(HostSpec.class), - any(Properties.class))).thenReturn(mockConnection); - when(mockConnection.createStatement()).thenReturn(mockStatement); - when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true, true, false); - when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); - when(mockResultSet.getString(eq(FIELD_SERVER_ID))) - .thenReturn("instance-0", "instance-1"); - when(mockResultSet.getStatement()).thenReturn(mockStatement); - when(mockStatement.getConnection()).thenReturn(mockConnection); - when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); - when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); - when(this.mockPluginService.getDialect()).thenReturn(mockDialect); - } - - @TearDown(Level.Iteration) - public void tearDownIteration() throws Exception { - closeable.close(); - } - - @Benchmark - public void initAndReleaseBaseLine() { - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithExecutionTimePlugin() throws SQLException { - try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING)) { - wrapper.releaseResources(); - return wrapper; - } - } - - private ConnectionWrapper getConnectionWrapper(Properties props, String connString) throws SQLException { - return new TestConnectionWrapper( - props, - connString, - mockConnectionProvider, - mockTargetDriverDialect, - mockConnectionPluginManager, - mockTelemetryFactory, - mockPluginService, - mockHostListProviderService, - mockPluginManagerService, - mockStorageService, - mockMonitorService); - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithAuroraHostListPlugin() throws SQLException { - try (ConnectionWrapper wrapper = getConnectionWrapper(useAuroraHostListPlugin(), CONNECTION_STRING)) { - wrapper.releaseResources(); - return wrapper; - } - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithExecutionTimeAndAuroraHostListPlugins() throws SQLException { - try (ConnectionWrapper wrapper = - getConnectionWrapper(useExecutionTimeAndAuroraHostListPlugins(), CONNECTION_STRING)) { - wrapper.releaseResources(); - return wrapper; - } - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin() throws SQLException { - try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { - wrapper.releaseResources(); - return wrapper; - } - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin() - throws SQLException { - try (ConnectionWrapper wrapper = - getConnectionWrapper(useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { - wrapper.releaseResources(); - return wrapper; - } - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin_internalConnectionPools() throws SQLException { - HikariPooledConnectionProvider provider = - new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); - Driver.setCustomConnectionProvider(provider); - try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { - wrapper.releaseResources(); - ConnectionProviderManager.releaseResources(); - Driver.resetCustomConnectionProvider(); - return wrapper; - } - } - - @Benchmark - public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin_internalConnectionPools() - throws SQLException { - HikariPooledConnectionProvider provider = - new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); - Driver.setCustomConnectionProvider(provider); - try (ConnectionWrapper wrapper = getConnectionWrapper( - useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { - wrapper.releaseResources(); - ConnectionProviderManager.releaseResources(); - Driver.resetCustomConnectionProvider(); - return wrapper; - } - } - - @Benchmark - public Statement executeStatementBaseline() throws SQLException { - try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); - Statement statement = wrapper.createStatement()) { - return statement; - } - } - - @Benchmark - public ResultSet executeStatementWithExecutionTimePlugin() throws SQLException { - try ( - ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); - Statement statement = wrapper.createStatement(); - ResultSet resultSet = statement.executeQuery("some sql")) { - return resultSet; - } - } - - @Benchmark - public ResultSet executeStatementWithTelemetryDisabled() throws SQLException { - try ( - ConnectionWrapper wrapper = getConnectionWrapper(disabledTelemetry(), CONNECTION_STRING); - Statement statement = wrapper.createStatement(); - ResultSet resultSet = statement.executeQuery("some sql")) { - return resultSet; - } - } - - @Benchmark - public ResultSet executeStatementWithTelemetry() throws SQLException { - try ( - ConnectionWrapper wrapper = getConnectionWrapper(useTelemetry(), CONNECTION_STRING); - Statement statement = wrapper.createStatement(); - ResultSet resultSet = statement.executeQuery("some sql")) { - return resultSet; - } - } - - Properties useExecutionTimePlugin() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "executionTime"); - return properties; - } - - Properties useAuroraHostListPlugin() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "auroraHostList"); - return properties; - } - - Properties useExecutionTimeAndAuroraHostListPlugins() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "executionTime,auroraHostList"); - return properties; - } - - Properties useReadWriteSplittingPlugin() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "readWriteSplitting"); - return properties; - } - - Properties useAuroraHostListAndReadWriteSplittingPlugin() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); - return properties; - } - - Properties useReadWriteSplittingPluginWithReaderLoadBalancing() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "readWriteSplitting"); - properties.setProperty("loadBalanceReadOnlyTraffic", "true"); - return properties; - } - - Properties useAuroraHostListAndReadWriteSplittingPluginWithReaderLoadBalancing() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); - properties.setProperty("loadBalanceReadOnlyTraffic", "true"); - return properties; - } - - Properties useTelemetry() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); - properties.setProperty("enableTelemetry", "true"); - properties.setProperty("telemetryMetricsBackend", "none"); - properties.setProperty("telemetryTracesBackend", "none"); - return properties; - } - - Properties disabledTelemetry() { - final Properties properties = new Properties(); - properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); - properties.setProperty("enableTelemetry", "false"); - return properties; - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.benchmarks; +// +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyBoolean; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.when; +// +// import com.zaxxer.hikari.HikariConfig; +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import org.openjdk.jmh.annotations.Benchmark; +// import org.openjdk.jmh.annotations.BenchmarkMode; +// import org.openjdk.jmh.annotations.Fork; +// import org.openjdk.jmh.annotations.Level; +// import org.openjdk.jmh.annotations.Measurement; +// import org.openjdk.jmh.annotations.Mode; +// import org.openjdk.jmh.annotations.OutputTimeUnit; +// import org.openjdk.jmh.annotations.Scope; +// import org.openjdk.jmh.annotations.Setup; +// import org.openjdk.jmh.annotations.State; +// import org.openjdk.jmh.annotations.TearDown; +// import org.openjdk.jmh.annotations.Warmup; +// import org.openjdk.jmh.profile.GCProfiler; +// import org.openjdk.jmh.runner.Runner; +// import org.openjdk.jmh.runner.RunnerException; +// import org.openjdk.jmh.runner.options.Options; +// import org.openjdk.jmh.runner.options.OptionsBuilder; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.ConnectionProvider; +// import software.amazon.jdbc.ConnectionProviderManager; +// import software.amazon.jdbc.Driver; +// import software.amazon.jdbc.HikariPooledConnectionProvider; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcMethod; +// import software.amazon.jdbc.PluginManagerService; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.benchmarks.testplugin.TestConnectionWrapper; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.monitoring.MonitorService; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.telemetry.GaugeCallable; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryGauge; +// import software.amazon.jdbc.wrapper.ConnectionWrapper; +// +// @State(Scope.Benchmark) +// @Fork(3) +// @Warmup(iterations = 3) +// @Measurement(iterations = 10) +// @BenchmarkMode(Mode.SingleShotTime) +// @OutputTimeUnit(TimeUnit.NANOSECONDS) +// public class PluginBenchmarks { +// +// private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; +// private static final String FIELD_SERVER_ID = "SERVER_ID"; +// private static final String FIELD_SESSION_ID = "SESSION_ID"; +// private static final String CONNECTION_STRING = "jdbc:postgresql://my.domain.com"; +// private static final String PG_CONNECTION_STRING = +// "jdbc:aws-wrapper:postgresql://instance-0.XYZ.us-east-2.rds.amazonaws.com"; +// private static final String TEST_HOST = "instance-0"; +// private static final int TEST_PORT = 5432; +// private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host(TEST_HOST).port(TEST_PORT).build(); +// +// @Mock private StorageService mockStorageService; +// @Mock private MonitorService mockMonitorService; +// @Mock private PluginService mockPluginService; +// @Mock private TargetDriverDialect mockTargetDriverDialect; +// @Mock private Dialect mockDialect; +// @Mock private ConnectionPluginManager mockConnectionPluginManager; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryGauge mockTelemetryGauge; +// @Mock private HostListProviderService mockHostListProviderService; +// @Mock private PluginManagerService mockPluginManagerService; +// @Mock ConnectionProvider mockConnectionProvider; +// @Mock Connection mockConnection; +// @Mock Statement mockStatement; +// @Mock ResultSet mockResultSet; +// private AutoCloseable closeable; +// +// public static void main(String[] args) throws RunnerException { +// Options opt = new OptionsBuilder() +// .include(PluginBenchmarks.class.getSimpleName()) +// .addProfiler(GCProfiler.class) +// .detectJvmArgs() +// .build(); +// +// new Runner(opt).run(); +// } +// +// @Setup(Level.Iteration) +// public void setUpIteration() throws Exception { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any())) +// .thenReturn(mockConnection); +// when(mockConnectionPluginManager.execute( +// any(), any(), any(), eq(JdbcMethod.CONNECTION_CREATESTATEMENT), any(), any())) +// .thenReturn(mockStatement); +// when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// // noinspection unchecked +// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); +// when(mockConnectionProvider.connect( +// anyString(), +// any(Dialect.class), +// any(TargetDriverDialect.class), +// any(HostSpec.class), +// any(Properties.class))).thenReturn(mockConnection); +// when(mockConnection.createStatement()).thenReturn(mockStatement); +// when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); +// when(mockResultSet.next()).thenReturn(true, true, false); +// when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); +// when(mockResultSet.getString(eq(FIELD_SERVER_ID))) +// .thenReturn("instance-0", "instance-1"); +// when(mockResultSet.getStatement()).thenReturn(mockStatement); +// when(mockStatement.getConnection()).thenReturn(mockConnection); +// when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); +// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); +// when(this.mockPluginService.getDialect()).thenReturn(mockDialect); +// } +// +// @TearDown(Level.Iteration) +// public void tearDownIteration() throws Exception { +// closeable.close(); +// } +// +// @Benchmark +// public void initAndReleaseBaseLine() { +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithExecutionTimePlugin() throws SQLException { +// try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING)) { +// wrapper.releaseResources(); +// return wrapper; +// } +// } +// +// private ConnectionWrapper getConnectionWrapper(Properties props, String connString) throws SQLException { +// return new TestConnectionWrapper( +// props, +// connString, +// mockConnectionProvider, +// mockTargetDriverDialect, +// mockConnectionPluginManager, +// mockTelemetryFactory, +// mockPluginService, +// mockHostListProviderService, +// mockPluginManagerService, +// mockStorageService, +// mockMonitorService); +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithAuroraHostListPlugin() throws SQLException { +// try (ConnectionWrapper wrapper = getConnectionWrapper(useAuroraHostListPlugin(), CONNECTION_STRING)) { +// wrapper.releaseResources(); +// return wrapper; +// } +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithExecutionTimeAndAuroraHostListPlugins() throws SQLException { +// try (ConnectionWrapper wrapper = +// getConnectionWrapper(useExecutionTimeAndAuroraHostListPlugins(), CONNECTION_STRING)) { +// wrapper.releaseResources(); +// return wrapper; +// } +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin() throws SQLException { +// try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { +// wrapper.releaseResources(); +// return wrapper; +// } +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin() +// throws SQLException { +// try (ConnectionWrapper wrapper = +// getConnectionWrapper(useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { +// wrapper.releaseResources(); +// return wrapper; +// } +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin_internalConnectionPools() throws SQLException { +// HikariPooledConnectionProvider provider = +// new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); +// Driver.setCustomConnectionProvider(provider); +// try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { +// wrapper.releaseResources(); +// ConnectionProviderManager.releaseResources(); +// Driver.resetCustomConnectionProvider(); +// return wrapper; +// } +// } +// +// @Benchmark +// public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin_internalConnectionPools() +// throws SQLException { +// HikariPooledConnectionProvider provider = +// new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); +// Driver.setCustomConnectionProvider(provider); +// try (ConnectionWrapper wrapper = getConnectionWrapper( +// useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { +// wrapper.releaseResources(); +// ConnectionProviderManager.releaseResources(); +// Driver.resetCustomConnectionProvider(); +// return wrapper; +// } +// } +// +// @Benchmark +// public Statement executeStatementBaseline() throws SQLException { +// try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); +// Statement statement = wrapper.createStatement()) { +// return statement; +// } +// } +// +// @Benchmark +// public ResultSet executeStatementWithExecutionTimePlugin() throws SQLException { +// try ( +// ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); +// Statement statement = wrapper.createStatement(); +// ResultSet resultSet = statement.executeQuery("some sql")) { +// return resultSet; +// } +// } +// +// @Benchmark +// public ResultSet executeStatementWithTelemetryDisabled() throws SQLException { +// try ( +// ConnectionWrapper wrapper = getConnectionWrapper(disabledTelemetry(), CONNECTION_STRING); +// Statement statement = wrapper.createStatement(); +// ResultSet resultSet = statement.executeQuery("some sql")) { +// return resultSet; +// } +// } +// +// @Benchmark +// public ResultSet executeStatementWithTelemetry() throws SQLException { +// try ( +// ConnectionWrapper wrapper = getConnectionWrapper(useTelemetry(), CONNECTION_STRING); +// Statement statement = wrapper.createStatement(); +// ResultSet resultSet = statement.executeQuery("some sql")) { +// return resultSet; +// } +// } +// +// Properties useExecutionTimePlugin() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "executionTime"); +// return properties; +// } +// +// Properties useAuroraHostListPlugin() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "auroraHostList"); +// return properties; +// } +// +// Properties useExecutionTimeAndAuroraHostListPlugins() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "executionTime,auroraHostList"); +// return properties; +// } +// +// Properties useReadWriteSplittingPlugin() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "readWriteSplitting"); +// return properties; +// } +// +// Properties useAuroraHostListAndReadWriteSplittingPlugin() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); +// return properties; +// } +// +// Properties useReadWriteSplittingPluginWithReaderLoadBalancing() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "readWriteSplitting"); +// properties.setProperty("loadBalanceReadOnlyTraffic", "true"); +// return properties; +// } +// +// Properties useAuroraHostListAndReadWriteSplittingPluginWithReaderLoadBalancing() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); +// properties.setProperty("loadBalanceReadOnlyTraffic", "true"); +// return properties; +// } +// +// Properties useTelemetry() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); +// properties.setProperty("enableTelemetry", "true"); +// properties.setProperty("telemetryMetricsBackend", "none"); +// properties.setProperty("telemetryTracesBackend", "none"); +// return properties; +// } +// +// Properties disabledTelemetry() { +// final Properties properties = new Properties(); +// properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); +// properties.setProperty("enableTelemetry", "false"); +// return properties; +// } +// } diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java index f238fbe06..2979a6555 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java @@ -24,7 +24,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import java.util.logging.Logger; import software.amazon.jdbc.ConnectionPlugin; @@ -59,18 +58,21 @@ public T execute(Class resultClass, Class excepti } @Override - public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable connectFunc) - throws SQLException { + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { LOGGER.finer(() -> String.format("connect=''%s''", hostSpec.getHost())); resources.add("connect"); return connectFunc.call(); } @Override - public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable forceConnectFunc) - throws SQLException { + public Connection forceConnect( + ConnectionContext connectionContext, + HostSpec hostSpec, + boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { LOGGER.finer(() -> String.format("forceConnect=''%s''", hostSpec.getHost())); resources.add("forceConnect"); return forceConnectFunc.call(); diff --git a/docs/development-guide/LoadablePlugins.md b/docs/development-guide/LoadablePlugins.md index 6dd713a73..0b0aec2b7 100644 --- a/docs/development-guide/LoadablePlugins.md +++ b/docs/development-guide/LoadablePlugins.md @@ -116,9 +116,12 @@ public class BadPlugin extends AbstractConnectionPlugin { return new HashSet<>(Collections.singletonList("*")); } - @Override - public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, boolean isInitialConnection, - JdbcCallable connectFunc) throws SQLException { + @Override + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { // Bad Practice #2: using driver-specific objects. // Not all drivers support the same configuration parameters. For instance, while MySQL Connector/J Supports "database", // PGJDBC uses "dbname" for database names. @@ -167,9 +170,12 @@ public class GoodExample extends AbstractConnectionPlugin { return jdbcMethodFunc.call(); } - @Override - public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, boolean isInitialConnection, - JdbcCallable connectFunc) throws SQLException { + @Override + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { if (PropertyDefinition.USER.getString(props) == null) { PropertyDefinition.TARGET_DRIVER_USER_PROPERTY_NAME.set(props, "defaultUser"); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java index 16c67469a..73b147b90 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java @@ -32,9 +32,7 @@ import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.ConnectInfo; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.SlidingExpirationCache; @@ -79,31 +77,32 @@ public HostSpec getHostSpecByStrategy(@NonNull List hosts, @NonNull Ho } @Override - public Connection connect(@NonNull String protocol, @NonNull Dialect dialect, - @NonNull TargetDriverDialect targetDriverDialect, @NonNull HostSpec hostSpec, - @NonNull Properties props) throws SQLException { - final Properties copy = PropertyUtils.copyProperties(props); - dialect.prepareConnectProperties(copy, protocol, hostSpec); + public Connection connect( + @NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException { + Dialect dialect = connectionContext.getDbDialect(); + Properties props = connectionContext.getPropsCopy(); + dialect.prepareConnectProperties(props, connectionContext.getProtocol(), hostSpec); final ComboPooledDataSource ds = databasePools.computeIfAbsent( hostSpec.getUrl(), - (key) -> createDataSource(protocol, hostSpec, copy, targetDriverDialect), + (key) -> createDataSource(connectionContext, hostSpec, props), poolExpirationCheckNanos ); - ds.setPassword(copy.getProperty(PropertyDefinition.PASSWORD.name)); + ds.setPassword(props.getProperty(PropertyDefinition.PASSWORD.name)); return ds.getConnection(); } protected ComboPooledDataSource createDataSource( - @NonNull String protocol, + @NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec, - @NonNull Properties props, - TargetDriverDialect driverDialect) { + @NonNull Properties props) { ConnectInfo connectInfo; + try { - connectInfo = driverDialect.prepareConnectInfo(protocol, hostSpec, props); + connectInfo = connectionContext.getDriverDialect() + .prepareConnectInfo(connectionContext.getProtocol(), hostSpec, props); } catch (SQLException ex) { throw new RuntimeException(ex); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java index a7b71e170..07f10a783 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java @@ -21,7 +21,6 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import software.amazon.jdbc.util.connection.ConnectionContext; @@ -46,7 +45,7 @@ T execute( * Establishes a connection to the given host using the given driver protocol and properties. If a * non-default {@link ConnectionProvider} has been set with * {@link Driver#setCustomConnectionProvider(ConnectionProvider)} and - * {@link ConnectionProvider#acceptsUrl(String, HostSpec, Properties)} returns true for the given + * {@link ConnectionProvider#acceptsUrl(ConnectionContext, HostSpec)} returns true for the given * protocol, host, and properties, the connection will be created by the non-default * ConnectionProvider. Otherwise, the connection will be created by the default * ConnectionProvider. The default ConnectionProvider will be {@link DriverConnectionProvider} for @@ -54,9 +53,8 @@ T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param driverProtocol the driver protocol that should be used to establish the connection + * @param connectionContext the connection info for the original connection * @param hostSpec the host details for the desired connection - * @param props the connection properties * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has * already established a physical connection in the past @@ -67,9 +65,8 @@ T execute( * host */ Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException; @@ -83,9 +80,8 @@ Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param driverProtocol the driver protocol that should be used to establish the connection + * @param connectionContext the connection info for the original connection. * @param hostSpec the host details for the desired connection - * @param props the connection properties * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has * already established a physical connection in the past @@ -96,9 +92,8 @@ Connection connect( * host */ Connection forceConnect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index cba4007fd..af9463d23 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -170,7 +170,7 @@ public void unlock() { *

The {@link DefaultConnectionPlugin} will always be initialized and attached as the last * connection plugin in the chain. * - * @param servicesContainer the service container for the services required by this class. + * @param servicesContainer the service container for the services required by this class. * @param props the configuration of the connection * @param pluginManagerService a reference to a plugin manager service * @param configurationProfile a profile configuration defined by the user @@ -354,7 +354,7 @@ public T execute( throw WrapperUtils.wrapExceptionIfNeeded( exceptionClass, new SQLException( - Messages.get("ConnectionPluginManager.invokedAgainstOldConnection", new Object[]{methodInvokeOn}))); + Messages.get("ConnectionPluginManager.invokedAgainstOldConnection", new Object[] {methodInvokeOn}))); } } @@ -371,7 +371,7 @@ public T execute( * Establishes a connection to the given host using the given driver protocol and properties. If a * non-default {@link ConnectionProvider} has been set with * {@link Driver#setCustomConnectionProvider(ConnectionProvider)} and - * {@link ConnectionProvider#acceptsUrl(String, HostSpec, Properties)} returns true for the given + * {@link ConnectionProvider#acceptsUrl(ConnectionContext, HostSpec)} returns true for the given * protocol, host, and properties, the connection will be created by the non-default * ConnectionProvider. Otherwise, the connection will be created by the default * ConnectionProvider. The default ConnectionProvider will be {@link DriverConnectionProvider} for @@ -379,9 +379,8 @@ public T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param driverProtocol the driver protocol that should be used to establish the connection + * @param connectionContext the connection info for the original connection * @param hostSpec the host details for the desired connection - * @param props the connection properties * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has * already established a physical connection in the past @@ -391,9 +390,8 @@ public T execute( * host */ public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { @@ -405,7 +403,7 @@ public Connection connect( return executeWithSubscribedPlugins( JdbcMethod.CONNECT, (plugin, func) -> - plugin.connect(driverProtocol, hostSpec, props, isInitialConnection, func), + plugin.connect(connectionContext, hostSpec, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); }, @@ -430,9 +428,8 @@ public Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param driverProtocol the driver protocol that should be used to establish the connection + * @param connectionContext the connection info for the original connection. * @param hostSpec the host details for the desired connection - * @param props the connection properties * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has * already established a physical connection in the past @@ -442,9 +439,8 @@ public Connection connect( * host */ public Connection forceConnect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { @@ -453,7 +449,7 @@ public Connection forceConnect( return executeWithSubscribedPlugins( JdbcMethod.FORCECONNECT, (plugin, func) -> - plugin.forceConnect(driverProtocol, hostSpec, props, isInitialConnection, func), + plugin.forceConnect(connectionContext, hostSpec, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); }, diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java index ebb475e7d..09f419bd1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java @@ -22,9 +22,6 @@ import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.units.qual.N; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.connection.ConnectionContext; /** @@ -73,21 +70,12 @@ HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param protocol the connection protocol (example "jdbc:mysql://") - * @param dialect the database dialect - * @param targetDriverDialect the target driver dialect + * @param connectionContext the connection info for the original connection. * @param hostSpec the HostSpec containing the host-port information for the host to connect to - * @param props the Properties to use for the connection * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ - Connection connect( - @NonNull String protocol, - @NonNull Dialect dialect, - @NonNull TargetDriverDialect targetDriverDialect, - @NonNull HostSpec hostSpec, - @NonNull Properties props) - throws SQLException; + Connection connect(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException; String getTargetName(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java index 92e5259d5..89921dfa0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java @@ -207,23 +207,21 @@ public static void resetConnectionInitFunc() { public void initConnection( final @Nullable Connection connection, - final @NonNull String protocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) throws SQLException { + final @NonNull ConnectionContext connectionContext, + final @NonNull HostSpec hostSpec) throws SQLException { final ConnectionInitFunc connectionInitFunc = Driver.getConnectionInitFunc(); if (connectionInitFunc == null) { return; } - connectionInitFunc.initConnection(connection, protocol, hostSpec, props); + connectionInitFunc.initConnection(connection, connectionContext, hostSpec); } public interface ConnectionInitFunc { void initConnection( final @Nullable Connection connection, - final @NonNull String protocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) throws SQLException; + final @NonNull ConnectionContext connectionContext, + final @NonNull HostSpec hostSpec) throws SQLException; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java index e465722f7..e2a2ed6c5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java @@ -30,11 +30,8 @@ import javax.sql.DataSource; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.exceptions.SQLLoginException; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.WrapperUtils; @@ -95,23 +92,16 @@ public HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param protocol The connection protocol (example "jdbc:mysql://") + * @param connectionContext the connection info for the original connection. * @param hostSpec The HostSpec containing the host-port information for the host to connect to - * @param props The Properties to use for the connection * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ @Override public Connection connect( - final @NonNull String protocol, - final @NonNull Dialect dialect, - final @NonNull TargetDriverDialect targetDriverDialect, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) - throws SQLException { - - final Properties copy = PropertyUtils.copyProperties(props); - dialect.prepareConnectProperties(copy, protocol, hostSpec); + final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec) throws SQLException { + final Properties copy = connectionContext.getPropsCopy(); + connectionContext.getDbDialect().prepareConnectProperties(copy, connectionContext.getProtocol(), hostSpec); Connection conn; @@ -120,7 +110,7 @@ public Connection connect( LOGGER.finest(() -> "Use a separate DataSource object to create a connection."); // use a new data source instance to instantiate a connection final DataSource ds = createDataSource(); - conn = this.openConnection(ds, protocol, targetDriverDialect, hostSpec, copy); + conn = this.openConnection(ds, connectionContext, hostSpec, copy); } else { @@ -129,7 +119,7 @@ public Connection connect( this.lock.lock(); LOGGER.finest(() -> "Use main DataSource object to create a connection."); try { - conn = this.openConnection(this.dataSource, protocol, targetDriverDialect, hostSpec, copy); + conn = this.openConnection(this.dataSource, connectionContext, hostSpec, copy); } finally { this.lock.unlock(); } @@ -144,16 +134,15 @@ public Connection connect( protected Connection openConnection( final @NonNull DataSource ds, - final @NonNull String protocol, - final @NonNull TargetDriverDialect targetDriverDialect, + final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec, final @NonNull Properties props) throws SQLException { final boolean enableGreenNodeReplacement = PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(props); try { - targetDriverDialect.prepareDataSource( + connectionContext.getDriverDialect().prepareDataSource( ds, - protocol, + connectionContext.getProtocol(), hostSpec, props); return ds.getConnection(); @@ -202,9 +191,9 @@ protected Connection openConnection( .host(fixedHost) .build(); - targetDriverDialect.prepareDataSource( + connectionContext.getDriverDialect().prepareDataSource( this.dataSource, - protocol, + connectionContext.getProtocol(), connectionHostSpec, props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java index e3dc6735b..62b27359a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java @@ -28,10 +28,8 @@ import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.exceptions.SQLLoginException; import software.amazon.jdbc.targetdriverdialect.ConnectInfo; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; @@ -91,31 +89,19 @@ public HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param protocol The connection protocol (example "jdbc:mysql://") - * @param dialect The database dialect - * @param targetDriverDialect The target driver dialect + * @param connectionContext the connection info for the original connection. * @param hostSpec The HostSpec containing the host-port information for the host to connect to - * @param props The Properties to use for the connection * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ @Override - public Connection connect( - final @NonNull String protocol, - final @NonNull Dialect dialect, - final @NonNull TargetDriverDialect targetDriverDialect, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) + public Connection connect(final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec) throws SQLException { + final Properties copy = connectionContext.getPropsCopy(); + final ConnectInfo connectInfo = + connectionContext.getDriverDialect().prepareConnectInfo(connectionContext.getProtocol(), hostSpec, copy); - // LOGGER.finest(() -> PropertyUtils.logProperties( - // PropertyUtils.maskProperties(props), "Connecting with properties: \n")); - - final Properties copy = PropertyUtils.copyProperties(props); - dialect.prepareConnectProperties(copy, protocol, hostSpec); - - final ConnectInfo connectInfo = targetDriverDialect.prepareConnectInfo(protocol, hostSpec, copy); - + connectionContext.getDbDialect().prepareConnectProperties(copy, connectionContext.getProtocol(), hostSpec); LOGGER.finest(() -> "Connecting to " + connectInfo.url + PropertyUtils.logProperties( PropertyUtils.maskProperties(connectInfo.props), @@ -127,7 +113,7 @@ public Connection connect( } catch (Throwable throwable) { - if (!PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(props)) { + if (!PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(copy)) { throw throwable; } @@ -172,7 +158,8 @@ public Connection connect( .host(fixedHost) .build(); - final ConnectInfo fixedConnectInfo = targetDriverDialect.prepareConnectInfo(protocol, connectionHostSpec, copy); + final ConnectInfo fixedConnectInfo = connectionContext.getDriverDialect().prepareConnectInfo( + connectionContext.getProtocol(), connectionHostSpec, copy); LOGGER.finest(() -> "Connecting to " + fixedConnectInfo.url + " after correcting the hostname from " + originalHost diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java index 31de59c9a..1f68ce0cf 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java @@ -35,9 +35,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.ConnectInfo; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Pair; import software.amazon.jdbc.util.PropertyUtils; @@ -207,7 +205,7 @@ public HikariPooledConnectionProvider( @Override public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { if (this.acceptsUrlFunc != null) { - return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionContext.getProps()); + return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionContext.getPropsCopy()); } final RdsUrlType urlType = rdsUtils.identifyRdsType(hostSpec.getHost()); @@ -240,18 +238,12 @@ public HostSpec getHostSpecByStrategy( } @Override - public Connection connect( - @NonNull String protocol, - @NonNull Dialect dialect, - @NonNull TargetDriverDialect targetDriverDialect, - @NonNull HostSpec hostSpec, - @NonNull Properties props) + public Connection connect(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException { - - final Properties copy = PropertyUtils.copyProperties(props); + final Properties propsCopy = connectionContext.getPropsCopy(); HostSpec connectionHostSpec = hostSpec; - if (PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(props) + if (PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(propsCopy) && rdsUtils.isRdsDns(hostSpec.getHost()) && rdsUtils.isGreenInstance(hostSpec.getHost())) { @@ -275,15 +267,16 @@ public Connection connect( } final HostSpec finalHostSpec = connectionHostSpec; - dialect.prepareConnectProperties(copy, protocol, finalHostSpec); + connectionContext.getDbDialect().prepareConnectProperties( + propsCopy, connectionContext.getProtocol(), finalHostSpec); final HikariDataSource ds = (HikariDataSource) HikariPoolsHolder.databasePools.computeIfAbsent( - Pair.create(hostSpec.getUrl(), getPoolKey(finalHostSpec, copy)), - (lambdaPoolKey) -> createHikariDataSource(protocol, finalHostSpec, copy, targetDriverDialect), + Pair.create(hostSpec.getUrl(), getPoolKey(finalHostSpec, propsCopy)), + (lambdaPoolKey) -> createHikariDataSource(connectionContext, finalHostSpec, propsCopy), poolExpirationCheckNanos ); - ds.setPassword(copy.getProperty(PropertyDefinition.PASSWORD.name)); + ds.setPassword(propsCopy.getProperty(PropertyDefinition.PASSWORD.name)); return ds.getConnection(); } @@ -308,29 +301,27 @@ public void releaseResources() { /** * Configures the default required settings for the internal connection pool. * - * @param config the {@link HikariConfig} to configure. By default, this method sets the - * jdbcUrl, exceptionOverrideClassName, username, and password. The - * HikariConfig passed to this method should be created via a - * {@link HikariPoolConfigurator}, which allows the user to specify any - * additional configuration properties. - * @param protocol the driver protocol that should be used to form connections - * @param hostSpec the host details used to form the connection - * @param connectionProps the connection properties - * @param targetDriverDialect the target driver dialect {@link TargetDriverDialect} + * @param config the {@link HikariConfig} to configure. By default, this method sets the + * jdbcUrl, exceptionOverrideClassName, username, and password. The + * HikariConfig passed to this method should be created via a + * {@link HikariPoolConfigurator}, which allows the user to specify any + * additional configuration properties. + * @param connectionContext the connection info for the original connection. + * @param hostSpec the host details used to form the connection + * @param connectionProps the connection properties */ protected void configurePool( final HikariConfig config, - final String protocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties connectionProps, - final @NonNull TargetDriverDialect targetDriverDialect) { + final Properties connectionProps) { final Properties copy = PropertyUtils.copyProperties(connectionProps); ConnectInfo connectInfo; try { - connectInfo = targetDriverDialect.prepareConnectInfo( - protocol, hostSpec, copy); + connectInfo = connectionContext.getDriverDialect().prepareConnectInfo( + connectionContext.getProtocol(), hostSpec, copy); } catch (SQLException ex) { throw new RuntimeException(ex); } @@ -379,7 +370,7 @@ public int getHostCount() { public Set getHosts() { return Collections.unmodifiableSet( HikariPoolsHolder.databasePools.getEntries().keySet().stream() - .map(poolKey -> (String) poolKey.getValue1()) + .map(Pair::getValue1) .collect(Collectors.toSet())); } @@ -388,7 +379,7 @@ public Set getHosts() { * * @return a set containing every key associated with an active connection pool */ - public Set getKeys() { + public Set> getKeys() { return HikariPoolsHolder.databasePools.getEntries().keySet(); } @@ -416,18 +407,17 @@ public void logConnections() { } HikariDataSource createHikariDataSource( - final String protocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, - final @NonNull TargetDriverDialect targetDriverDialect) { + final Properties props) { HikariConfig config = poolConfigurator.configurePool(hostSpec, props); - configurePool(config, protocol, hostSpec, props, targetDriverDialect); + configurePool(config, connectionContext, hostSpec, props); return new HikariDataSource(config); } // For testing purposes only - void setDatabasePools(SlidingExpirationCache connectionPools) { + void setDatabasePools(SlidingExpirationCache, AutoCloseable> connectionPools) { HikariPoolsHolder.databasePools = connectionPools; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPoolsHolder.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPoolsHolder.java index 7eb2389e3..bafd526f5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPoolsHolder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPoolsHolder.java @@ -20,7 +20,7 @@ import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class HikariPoolsHolder { - static SlidingExpirationCache databasePools = + static SlidingExpirationCache, AutoCloseable> databasePools = new SlidingExpirationCache<>( null, (hikariDataSource) -> { diff --git a/wrapper/src/main/java/software/amazon/jdbc/LeastConnectionsHostSelector.java b/wrapper/src/main/java/software/amazon/jdbc/LeastConnectionsHostSelector.java index 88b90624c..4a0851f05 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/LeastConnectionsHostSelector.java +++ b/wrapper/src/main/java/software/amazon/jdbc/LeastConnectionsHostSelector.java @@ -31,10 +31,10 @@ public class LeastConnectionsHostSelector implements HostSelector { public static final String STRATEGY_LEAST_CONNECTIONS = "leastConnections"; - private final SlidingExpirationCache databasePools; + private final SlidingExpirationCache, AutoCloseable> databasePools; public LeastConnectionsHostSelector( - SlidingExpirationCache databasePools) { + SlidingExpirationCache, AutoCloseable> databasePools) { this.databasePools = databasePools; } @@ -50,7 +50,7 @@ public HostSpec getHost( getNumConnections(hostSpec1, this.databasePools) - getNumConnections(hostSpec2, this.databasePools)) .collect(Collectors.toList()); - if (eligibleHosts.size() == 0) { + if (eligibleHosts.isEmpty()) { throw new SQLException(Messages.get("HostSelector.noHostsMatchingRole", new Object[]{role})); } @@ -59,10 +59,10 @@ public HostSpec getHost( private int getNumConnections( final HostSpec hostSpec, - final SlidingExpirationCache databasePools) { + final SlidingExpirationCache, AutoCloseable> databasePools) { int numConnections = 0; final String url = hostSpec.getUrl(); - for (final Map.Entry entry : + for (final Map.Entry, AutoCloseable> entry : databasePools.getEntries().entrySet()) { if (!url.equals(entry.getKey().getValue1())) { continue; diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index df034895d..74fe058cf 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -48,6 +48,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.CacheMap; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -65,8 +66,8 @@ public class PartialPluginService implements PluginService, CanReleaseResources, protected static final CacheMap hostAvailabilityExpiringCache = new CacheMap<>(); protected final FullServicesContainer servicesContainer; + protected final ConnectionContext connectionContext; protected final ConnectionPluginManager pluginManager; - protected final Properties props; protected volatile HostListProvider hostListProvider; protected List allHosts = new ArrayList<>(); protected HostSpec currentHostSpec; @@ -74,39 +75,22 @@ public class PartialPluginService implements PluginService, CanReleaseResources, protected boolean isInTransaction; protected final ExceptionManager exceptionManager; protected final @Nullable ExceptionHandler exceptionHandler; - protected final String originalUrl; - protected final String driverProtocol; - protected TargetDriverDialect targetDriverDialect; - protected Dialect dbDialect; protected @Nullable final ConfigurationProfile configurationProfile; protected final ConnectionProviderManager connectionProviderManager; public PartialPluginService( - @NonNull final FullServicesContainer servicesContainer, - @NonNull final Properties props, - @NonNull final String originalUrl, - @NonNull final String targetDriverProtocol, - @NonNull final TargetDriverDialect targetDriverDialect, - @NonNull final Dialect dbDialect) { + @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionContext connectionContext) { this( servicesContainer, new ExceptionManager(), - props, - originalUrl, - targetDriverProtocol, - targetDriverDialect, - dbDialect, + connectionContext, null); } public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, @NonNull final ExceptionManager exceptionManager, - @NonNull final Properties props, - @NonNull final String originalUrl, - @NonNull final String targetDriverProtocol, - @NonNull final TargetDriverDialect targetDriverDialect, - @NonNull final Dialect dbDialect, + @NonNull final ConnectionContext connectionContext, @Nullable final ConfigurationProfile configurationProfile) { this.servicesContainer = servicesContainer; this.servicesContainer.setHostListProviderService(this); @@ -114,11 +98,7 @@ public PartialPluginService( this.servicesContainer.setPluginManagerService(this); this.pluginManager = servicesContainer.getConnectionPluginManager(); - this.props = props; - this.originalUrl = originalUrl; - this.driverProtocol = targetDriverProtocol; - this.targetDriverDialect = targetDriverDialect; - this.dbDialect = dbDialect; + this.connectionContext = connectionContext; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; @@ -130,8 +110,8 @@ public PartialPluginService( ? this.configurationProfile.getExceptionHandler() : null; - HostListProviderSupplier supplier = this.dbDialect.getHostListProvider(); - this.hostListProvider = supplier.getProvider(this.props, this.originalUrl, this.servicesContainer); + HostListProviderSupplier supplier = this.connectionContext.getDbDialect().getHostListProvider(); + this.hostListProvider = supplier.getProvider(this.connectionContext, this.servicesContainer); } @Override @@ -182,9 +162,14 @@ public HostSpec getInitialConnectionHostSpec() { return this.initialConnectionHostSpec; } + @Override + public ConnectionContext getConnectionContext() { + return this.connectionContext; + } + @Override public String getOriginalUrl() { - return this.originalUrl; + return this.connectionContext.getUrl(); } @Override @@ -229,13 +214,13 @@ public ConnectionProvider getDefaultConnectionProvider() { public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = - this.connectionProviderManager.getConnectionProvider(this.driverProtocol, host, props); + this.connectionProviderManager.getConnectionProvider(this.connectionContext, host); return (connectionProvider instanceof PooledConnectionProvider); } @Override public String getDriverProtocol() { - return this.driverProtocol; + return this.connectionContext.getProtocol(); } @Override @@ -519,8 +504,7 @@ public Connection forceConnect( final Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { - return this.pluginManager.forceConnect( - this.driverProtocol, hostSpec, props, true, pluginToSkip); + return this.pluginManager.forceConnect(this.connectionContext, hostSpec,true, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -540,7 +524,7 @@ public void releaseResources() { @Override public boolean isNetworkException(Throwable throwable) { - return this.isNetworkException(throwable, this.targetDriverDialect); + return this.isNetworkException(throwable, this.connectionContext.getDriverDialect()); } @Override @@ -548,7 +532,9 @@ public boolean isNetworkException(final Throwable throwable, @Nullable TargetDri if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(throwable, targetDriverDialect); } - return this.exceptionManager.isNetworkException(this.dbDialect, throwable, targetDriverDialect); + + return this.exceptionManager.isNetworkException( + this.connectionContext.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -556,12 +542,13 @@ public boolean isNetworkException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(sqlState); } - return this.exceptionManager.isNetworkException(this.dbDialect, sqlState); + + return this.exceptionManager.isNetworkException(this.connectionContext.getDbDialect(), sqlState); } @Override public boolean isLoginException(Throwable throwable) { - return this.isLoginException(throwable, this.targetDriverDialect); + return this.isLoginException(throwable, this.connectionContext.getDriverDialect()); } @Override @@ -569,7 +556,9 @@ public boolean isLoginException(final Throwable throwable, @Nullable TargetDrive if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(throwable, targetDriverDialect); } - return this.exceptionManager.isLoginException(this.dbDialect, throwable, targetDriverDialect); + + return this.exceptionManager.isLoginException( + this.connectionContext.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -577,17 +566,17 @@ public boolean isLoginException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(sqlState); } - return this.exceptionManager.isLoginException(this.dbDialect, sqlState); + return this.exceptionManager.isLoginException(this.connectionContext.getDbDialect(), sqlState); } @Override public Dialect getDialect() { - return this.dbDialect; + return this.connectionContext.getDbDialect(); } @Override public TargetDriverDialect getTargetDriverDialect() { - return this.targetDriverDialect; + return this.connectionContext.getDriverDialect(); } @Override @@ -635,12 +624,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.props)); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getPropsCopy())); } @Override public Properties getProperties() { - return this.props; + return this.connectionContext.getPropsCopy(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index f00a75777..007b7043b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -125,7 +125,7 @@ public PluginServiceImpl( this.sessionStateService = sessionStateService != null ? sessionStateService - : new SessionStateServiceImpl(this, this.connectionContext.getProps()); + : new SessionStateServiceImpl(this, this.connectionContext.getPropsCopy()); this.exceptionHandler = this.configurationProfile != null && this.configurationProfile.getExceptionHandler() != null ? this.configurationProfile.getExceptionHandler() @@ -291,7 +291,7 @@ public EnumSet setCurrentConnection( this.setInTransaction(false); if (isInTransaction - && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionContext.getProps())) { + && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionContext.getPropsCopy())) { try { oldConnection.rollback(); } catch (final SQLException e) { @@ -592,7 +592,7 @@ public Connection connect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.connect( - this.connectionContext.getProtocol(), hostSpec, props, this.currentConnection == null, pluginToSkip); + this.connectionContext, hostSpec, this.currentConnection == null, pluginToSkip); } @Override @@ -610,7 +610,7 @@ public Connection forceConnect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.forceConnect( - this.connectionContext.getProtocol(), hostSpec, props, this.currentConnection == null, pluginToSkip); + this.connectionContext, hostSpec, this.currentConnection == null, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -747,12 +747,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getProps())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getPropsCopy())); } @Override public Properties getProperties() { - return this.connectionContext.getProps(); + return this.connectionContext.getPropsCopy(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java index e64352899..ede526f9b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java @@ -89,21 +89,20 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> { + return (connectionContext, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsHostListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, IS_READER_QUERY, IS_WRITER_QUERY); } + return new AuroraHostListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java index ad21fb9f9..e00cdc3df 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java @@ -135,21 +135,20 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> { + return (connectionContext, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsHostListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, IS_READER_QUERY, IS_WRITER_QUERY); } + return new AuroraHostListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index aa24188fe..0eecadbfa 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -131,7 +131,7 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th return this.dialect; } - final String userDialectSetting = DIALECT.getString(connectionContext.getProps()); + final String userDialectSetting = DIALECT.getString(connectionContext.getPropsCopy()); final String dialectCode = !StringUtils.isNullOrEmpty(userDialectSetting) ? userDialectSetting : knownEndpointDialects.get(connectionContext.getUrl()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java index 3b368a8a1..52e104c07 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java @@ -104,8 +104,8 @@ public List getDialectUpdateCandidates() { } public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> - new ConnectionStringHostListProvider(properties, initialUrl, servicesContainer.getHostListProviderService()); + return (connectionContext, servicesContainer) -> + new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java index 87c4c705c..300436604 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java @@ -106,8 +106,8 @@ public List getDialectUpdateCandidates() { @Override public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> - new ConnectionStringHostListProvider(properties, initialUrl, servicesContainer.getHostListProviderService()); + return (connectionContext, servicesContainer) -> + new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java index 930cf1631..756591474 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java @@ -94,12 +94,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> { + return (connectionContext, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsMultiAzHostListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -109,8 +108,7 @@ public HostListProviderSupplier getHostListProvider() { } else { return new RdsMultiAzDbClusterListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java index d91c867c0..23e9d6b78 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java @@ -80,12 +80,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> { + return (connectionContext, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsMultiAzHostListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -96,8 +95,7 @@ public HostListProviderSupplier getHostListProvider() { } else { return new RdsMultiAzDbClusterListProvider( - properties, - initialUrl, + connectionContext, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java index 65b9eb544..e1ecc937e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java @@ -81,8 +81,8 @@ public List getDialectUpdateCandidates() { @Override public HostListProviderSupplier getHostListProvider() { - return (properties, initialUrl, servicesContainer) -> - new ConnectionStringHostListProvider(properties, initialUrl, servicesContainer.getHostListProviderService()); + return (connectionContext, servicesContainer) -> + new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java index fc53f9e1d..ab1bac786 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java @@ -17,9 +17,9 @@ package software.amazon.jdbc.hostlistprovider; -import java.util.Properties; import java.util.logging.Logger; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; public class AuroraHostListProvider extends RdsHostListProvider { @@ -27,14 +27,13 @@ public class AuroraHostListProvider extends RdsHostListProvider { static final Logger LOGGER = Logger.getLogger(AuroraHostListProvider.class.getName()); public AuroraHostListProvider( - final Properties properties, - final String originalUrl, + final ConnectionContext connectionContext, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery) { - super(properties, - originalUrl, + super( + connectionContext, servicesContainer, topologyQuery, nodeIdQuery, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index f088ef864..5a4d4fc9c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -21,11 +21,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Properties; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.AwsWrapperProperty; -import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; @@ -62,7 +60,7 @@ public ConnectionStringHostListProvider( final @NonNull HostListProviderService hostListProviderService, final @NonNull ConnectionUrlParser connectionUrlParser) { this.connectionContext = connectionContext; - this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionContext.getProps()); + this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionContext.getPropsCopy()); this.connectionUrlParser = connectionUrlParser; this.hostListProviderService = hostListProviderService; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java index 738eebcc3..0883ff909 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java @@ -55,6 +55,7 @@ import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SynchronousExecutor; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.CacheMap; public class RdsHostListProvider implements DynamicHostListProvider { @@ -94,7 +95,7 @@ public class RdsHostListProvider implements DynamicHostListProvider { protected final FullServicesContainer servicesContainer; protected final HostListProviderService hostListProviderService; - protected final String originalUrl; + protected final ConnectionContext connectionContext; protected final String topologyQuery; protected final String nodeIdQuery; protected final String isReaderQuery; @@ -123,14 +124,12 @@ public class RdsHostListProvider implements DynamicHostListProvider { } public RdsHostListProvider( - final Properties properties, - final String originalUrl, + final ConnectionContext connectionContext, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery) { - this.properties = properties; - this.originalUrl = originalUrl; + this.connectionContext = connectionContext; this.servicesContainer = servicesContainer; this.hostListProviderService = servicesContainer.getHostListProviderService(); this.topologyQuery = topologyQuery; @@ -151,11 +150,11 @@ protected void init() throws SQLException { // initial topology is based on connection string this.initialHostList = - connectionUrlParser.getHostsFromConnectionUrl(this.originalUrl, false, + connectionUrlParser.getHostsFromConnectionUrl(this.connectionContext.getUrl(), false, this.hostListProviderService::getHostSpecBuilder); if (this.initialHostList == null || this.initialHostList.isEmpty()) { throw new SQLException(Messages.get("RdsHostListProvider.parsedListEmpty", - new Object[] {this.originalUrl})); + new Object[] {this.connectionContext.getUrl()})); } this.initialHostSpec = this.initialHostList.get(0); this.hostListProviderService.setInitialConnectionHostSpec(this.initialHostSpec); @@ -297,9 +296,7 @@ protected ClusterSuggestedResult getSuggestedClusterId(final String url) { if (key.equals(url)) { return new ClusterSuggestedResult(url, isPrimaryCluster); } - if (hosts == null) { - continue; - } + for (final HostSpec host : hosts) { if (host.getHostAndPort().equals(url)) { LOGGER.finest(() -> Messages.get("RdsHostListProvider.suggestedClusterId", diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java index a63323176..663b024f2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java @@ -26,13 +26,13 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; -import java.util.Properties; import java.util.logging.Logger; import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.connection.ConnectionContext; public class RdsMultiAzDbClusterListProvider extends RdsHostListProvider { private final String fetchWriterNodeQuery; @@ -40,8 +40,7 @@ public class RdsMultiAzDbClusterListProvider extends RdsHostListProvider { static final Logger LOGGER = Logger.getLogger(RdsMultiAzDbClusterListProvider.class.getName()); public RdsMultiAzDbClusterListProvider( - final Properties properties, - final String originalUrl, + final ConnectionContext connectionContext, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, @@ -49,8 +48,8 @@ public RdsMultiAzDbClusterListProvider( final String fetchWriterNodeQuery, final String fetchWriterNodeQueryHeader ) { - super(properties, - originalUrl, + super( + connectionContext, servicesContainer, topologyQuery, nodeIdQuery, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index c37a95bba..d2aacbab1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -510,11 +510,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.servicesContainer.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.servicesContainer.getPluginService().getOriginalUrl(), - this.servicesContainer.getPluginService().getDriverProtocol(), - this.servicesContainer.getPluginService().getTargetDriverDialect(), - this.servicesContainer.getPluginService().getDialect(), - this.properties + this.servicesContainer.getPluginService().getConnectionContext() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java index 5fd965006..97d2f7f0b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java @@ -19,7 +19,6 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.List; -import java.util.Properties; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.logging.Logger; @@ -32,6 +31,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; @@ -56,14 +56,13 @@ public class MonitoringRdsHostListProvider extends RdsHostListProvider protected final String writerTopologyQuery; public MonitoringRdsHostListProvider( - final Properties properties, - final String originalUrl, + final ConnectionContext connectionContext, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery, final String writerTopologyQuery) { - super(properties, originalUrl, servicesContainer, topologyQuery, nodeIdQuery, isReaderQuery); + super(connectionContext, servicesContainer, topologyQuery, nodeIdQuery, isReaderQuery); this.servicesContainer = servicesContainer; this.pluginService = servicesContainer.getPluginService(); this.writerTopologyQuery = writerTopologyQuery; @@ -87,11 +86,7 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.originalUrl, - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.properties, + this.connectionContext, (servicesContainer) -> new ClusterTopologyMonitorImpl( this.servicesContainer, this.clusterId, @@ -132,11 +127,7 @@ protected void clusterIdChanged(final String oldClusterId) throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.originalUrl, - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.properties, + this.connectionContext, (servicesContainer) -> existingMonitor); assert monitorService.get(ClusterTopologyMonitorImpl.class, this.clusterId) == existingMonitor; existingMonitor.setClusterId(this.clusterId); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java index 730d15e7a..6df75253a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java @@ -17,9 +17,9 @@ package software.amazon.jdbc.hostlistprovider.monitoring; import java.sql.SQLException; -import java.util.Properties; import java.util.logging.Logger; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; public class MonitoringRdsMultiAzHostListProvider extends MonitoringRdsHostListProvider { @@ -29,8 +29,7 @@ public class MonitoringRdsMultiAzHostListProvider extends MonitoringRdsHostListP protected final String fetchWriterNodeColumnName; public MonitoringRdsMultiAzHostListProvider( - final Properties properties, - final String originalUrl, + final ConnectionContext connectionContext, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, @@ -38,8 +37,7 @@ public MonitoringRdsMultiAzHostListProvider( final String fetchWriterNodeQuery, final String fetchWriterNodeColumnName) { super( - properties, - originalUrl, + connectionContext, servicesContainer, topologyQuery, nodeIdQuery, @@ -51,22 +49,18 @@ public MonitoringRdsMultiAzHostListProvider( @Override protected ClusterTopologyMonitor initMonitor() throws SQLException { - return this.servicesContainer.getMonitorService().runIfAbsent(MultiAzClusterTopologyMonitorImpl.class, + return this.servicesContainer.getMonitorService().runIfAbsent( + MultiAzClusterTopologyMonitorImpl.class, this.clusterId, this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.originalUrl, - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.properties, + this.connectionContext, (servicesContainer) -> new MultiAzClusterTopologyMonitorImpl( servicesContainer, this.clusterId, this.initialHostSpec, this.properties, - this.hostListProviderService, this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java index 36bab8f90..6ff7e9677 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MultiAzClusterTopologyMonitorImpl.java @@ -24,7 +24,6 @@ import java.time.Instant; import java.util.Properties; import java.util.logging.Logger; -import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.StringUtils; @@ -41,7 +40,6 @@ public MultiAzClusterTopologyMonitorImpl( final String clusterId, final HostSpec initialHostSpec, final Properties properties, - final HostListProviderService hostListProviderService, final HostSpec clusterInstanceTemplate, final long refreshRateNano, final long highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java index 035e4ecf9..0d704c53d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java @@ -21,7 +21,6 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.HostListProviderService; @@ -30,6 +29,7 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.OldConnectionSuggestedAction; +import software.amazon.jdbc.util.connection.ConnectionContext; public abstract class AbstractConnectionPlugin implements ConnectionPlugin { @@ -50,9 +50,8 @@ public T execute( @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -61,9 +60,8 @@ public Connection connect( @Override public Connection forceConnect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { @@ -89,9 +87,7 @@ public HostSpec getHostSpecByStrategy(final List hosts, final HostRole @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 9b42d33ea..19454b941 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -36,6 +36,7 @@ import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionContext; public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { @@ -82,8 +83,11 @@ public Set getSubscribedMethods() { } @Override - public Connection connect(final String driverProtocol, final HostSpec hostSpec, final Properties props, - final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { final Connection conn = connectFunc.call(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 01b9bf8e9..6d8e6d558 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -39,6 +39,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlugin { @@ -97,10 +98,10 @@ public static VerifyOpenedConnectionType fromValue(String value) { } private final PluginService pluginService; - private HostListProviderService hostListProviderService; private final RdsUtils rdsUtils = new RdsUtils(); + private final VerifyOpenedConnectionType verifyOpenedConnectionType; + private HostListProviderService hostListProviderService; - private VerifyOpenedConnectionType verifyOpenedConnectionType = null; static { PropertyDefinition.registerPluginProperties(AuroraInitialConnectionStrategyPlugin.class); @@ -119,9 +120,7 @@ public Set getSubscribedMethods() { @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -134,15 +133,13 @@ public void initHostProvider( @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost()); - + final Properties props = connectionContext.getPropsCopy(); if (type == RdsUrlType.RDS_WRITER_CLUSTER || isInitialConnection && this.verifyOpenedConnectionType == VerifyOpenedConnectionType.WRITER) { Connection writerCandidateConn = this.getVerifiedWriterConnection(props, isInitialConnection, connectFunc); @@ -369,8 +366,6 @@ private HostSpec getReader(final Properties props) throws SQLException { if (this.pluginService.acceptsStrategy(HostRole.READER, strategy)) { try { return this.pluginService.getHostSpecByStrategy(HostRole.READER, strategy); - } catch (UnsupportedOperationException ex) { - throw ex; } catch (SQLException ex) { // host isn't found return null; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java index e5986cc23..14960187f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java @@ -50,6 +50,7 @@ import software.amazon.jdbc.util.Pair; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -184,13 +185,12 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, props, connectFunc); + return connectInternal(hostSpec, connectionContext.getPropsCopy(), connectFunc); } private Connection connectInternal(HostSpec hostSpec, Properties props, @@ -226,13 +226,12 @@ private Connection connectInternal(HostSpec hostSpec, Properties props, @Override public Connection forceConnect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, props, forceConnectFunc); + return connectInternal(hostSpec, connectionContext.getPropsCopy(), forceConnectFunc); } /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java index 4eaea61fe..f5b071d44 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java @@ -21,12 +21,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; -import java.util.Properties; import java.util.Set; import java.util.logging.Logger; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.connection.ConnectionContext; public class ConnectTimeConnectionPlugin extends AbstractConnectionPlugin { @@ -43,8 +43,11 @@ public Set getSubscribedMethods() { } @Override - public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { + public Connection connect( + ConnectionContext connectionContext, + HostSpec hostSpec, + boolean isInitialConnection, + JdbcCallable connectFunc) throws SQLException { final long startTime = System.nanoTime(); final Connection result = connectFunc.call(); @@ -59,9 +62,11 @@ public Connection connect(String driverProtocol, HostSpec hostSpec, Properties p } @Override - public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable forceConnectFunc) - throws SQLException { + public Connection forceConnect( + ConnectionContext connectionContext, + HostSpec hostSpec, + boolean isInitialConnection, + JdbcCallable forceConnectFunc) throws SQLException { final long startTime = System.nanoTime(); final Connection result = forceConnectFunc.call(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java index 32a1ad509..7253df836 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java @@ -25,7 +25,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -46,6 +45,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlMethodAnalyzer; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; @@ -74,27 +74,15 @@ public DefaultConnectionPlugin( final PluginManagerService pluginManagerService) { this(pluginService, defaultConnProvider, - effectiveConnProvider, pluginManagerService, new ConnectionProviderManager(defaultConnProvider, effectiveConnProvider)); } public DefaultConnectionPlugin( - final PluginService pluginService, - final ConnectionProvider defaultConnProvider, - final @Nullable ConnectionProvider effectiveConnProvider, - final PluginManagerService pluginManagerService, - final ConnectionProviderManager connProviderManager) { - if (pluginService == null) { - throw new IllegalArgumentException("pluginService"); - } - if (pluginManagerService == null) { - throw new IllegalArgumentException("pluginManagerService"); - } - if (defaultConnProvider == null) { - throw new IllegalArgumentException("connectionProvider"); - } - + final @NonNull PluginService pluginService, + final @NonNull ConnectionProvider defaultConnProvider, + final @NonNull PluginManagerService pluginManagerService, + final @NonNull ConnectionProviderManager connProviderManager) { this.pluginService = pluginService; this.pluginManagerService = pluginManagerService; this.defaultConnProvider = defaultConnProvider; @@ -166,47 +154,37 @@ public T execute( @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { - ConnectionProvider connProvider = this.connProviderManager.getConnectionProvider(driverProtocol, hostSpec, props); + ConnectionProvider connProvider = this.connProviderManager.getConnectionProvider(connectionContext, hostSpec); // It's guaranteed that this plugin is always the last in plugin chain so connectFunc can be // ignored. - return connectInternal(driverProtocol, hostSpec, props, connProvider, isInitialConnection); + return connectInternal(connectionContext, hostSpec, connProvider, isInitialConnection); } private Connection connectInternal( - String driverProtocol, - HostSpec hostSpec, - Properties props, - ConnectionProvider connProvider, - final boolean isInitialConnection) - throws SQLException { - + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final ConnectionProvider connProvider, + final boolean isInitialConnection) throws SQLException { TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext( connProvider.getTargetName(), TelemetryTraceLevel.NESTED); Connection conn; try { - conn = connProvider.connect( - driverProtocol, - this.pluginService.getDialect(), - this.pluginService.getTargetDriverDialect(), - hostSpec, - props); + conn = connProvider.connect(connectionContext, hostSpec); } finally { if (telemetryContext != null) { telemetryContext.closeContext(); } } - this.connProviderManager.initConnection(conn, driverProtocol, hostSpec, props); + this.connProviderManager.initConnection(conn, connectionContext, hostSpec); this.pluginService.setAvailability(hostSpec.asAliases(), HostAvailability.AVAILABLE); if (isInitialConnection) { @@ -218,16 +196,15 @@ private Connection connectInternal( @Override public Connection forceConnect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { // It's guaranteed that this plugin is always the last in plugin chain so forceConnectFunc can be // ignored. - return connectInternal(driverProtocol, hostSpec, props, this.defaultConnProvider, isInitialConnection); + return connectInternal(connectionContext, hostSpec, this.defaultConnProvider, isInitialConnection); } @Override @@ -265,9 +242,7 @@ public HostSpec getHostSpecByStrategy(final List hosts, final HostRole @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java index 7bfe67b35..be72e8b32 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java @@ -42,6 +42,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -126,9 +127,8 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java index a3a717638..0fc3310a8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java @@ -57,6 +57,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; public class BlueGreenStatusMonitor { @@ -598,7 +599,8 @@ protected void initHostListProvider() { return; } - final Properties hostListProperties = PropertyUtils.copyProperties(this.props); + final ConnectionContext originalContext = this.pluginService.getConnectionContext(); + final Properties hostListProperties = originalContext.getPropsCopy(); // Need to instantiate a separate HostListProvider with // a special unique clusterId to avoid interference with other HostListProviders opened for this cluster. @@ -610,16 +612,14 @@ protected void initHostListProvider() { LOGGER.finest(() -> Messages.get("bgd.createHostListProvider", new Object[] {this.role, RdsHostListProvider.CLUSTER_ID.getString(hostListProperties)})); - String protocol = this.connectionUrlParser.getProtocol(this.pluginService.getOriginalUrl()); final HostSpec connectionHostSpecCopy = this.connectionHostSpec.get(); if (connectionHostSpecCopy != null) { - String hostListProviderUrl = String.format("%s%s/", protocol, connectionHostSpecCopy.getHostAndPort()); - this.hostListProvider = this.pluginService.getDialect() - .getHostListProvider() - .getProvider( - hostListProperties, - hostListProviderUrl, - this.servicesContainer); + String hostListProviderUrl = + String.format("%s%s/", originalContext.getProtocol(), connectionHostSpecCopy.getHostAndPort()); + ConnectionContext newContext = new ConnectionContext( + hostListProviderUrl, originalContext.getProtocol(), originalContext.getDriverDialect(), hostListProperties); + this.hostListProvider = + this.pluginService.getDialect().getHostListProvider().getProvider(newContext, this.servicesContainer); } else { LOGGER.warning(() -> Messages.get("bgd.hostSpecNull")); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java index 5a272e922..1ade11648 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java @@ -41,6 +41,7 @@ import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.monitoring.MonitorErrorResponse; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -163,12 +164,10 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { if (!this.rdsUtils.isRdsCustomClusterDns(hostSpec.getHost())) { return connectFunc.call(); } @@ -218,11 +217,7 @@ protected CustomEndpointMonitor createMonitorIfAbsent(Properties props) throws S this.servicesContainer.getStorageService(), this.pluginService.getTelemetryFactory(), this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.props, + this.pluginService.getConnectionContext(), (servicesContainer) -> new CustomEndpointMonitorImpl( servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java index 5335b444c..80b61d8c4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java @@ -29,6 +29,7 @@ import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; public class DeveloperConnectionPlugin extends AbstractConnectionPlugin implements ExceptionSimulator { @@ -142,34 +143,29 @@ protected void raiseException( @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - - this.raiseExceptionOnConnectIfNeeded(driverProtocol, hostSpec, props, isInitialConnection); - return super.connect(driverProtocol, hostSpec, props, isInitialConnection, connectFunc); + final JdbcCallable connectFunc) throws SQLException { + this.raiseExceptionOnConnectIfNeeded(connectionContext, hostSpec, isInitialConnection); + return super.connect(connectionContext, hostSpec, isInitialConnection, connectFunc); } @Override public Connection forceConnect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - this.raiseExceptionOnConnectIfNeeded(driverProtocol, hostSpec, props, isInitialConnection); - return super.connect(driverProtocol, hostSpec, props, isInitialConnection, forceConnectFunc); + this.raiseExceptionOnConnectIfNeeded(connectionContext, hostSpec, isInitialConnection); + return super.connect(connectionContext, hostSpec, isInitialConnection, forceConnectFunc); } protected void raiseExceptionOnConnectIfNeeded( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection) throws SQLException { @@ -178,10 +174,7 @@ protected void raiseExceptionOnConnectIfNeeded( } else if (ExceptionSimulatorManager.connectCallback != null) { this.raiseExceptionOnConnect( ExceptionSimulatorManager.connectCallback.getExceptionToRaise( - driverProtocol, - hostSpec, - props, - isInitialConnection)); + connectionContext, hostSpec, isInitialConnection)); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java index aff02da98..39f6a2453 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java @@ -17,13 +17,12 @@ package software.amazon.jdbc.plugin.dev; import java.sql.SQLException; -import java.util.Properties; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.util.connection.ConnectionContext; public interface ExceptionSimulatorConnectCallback { SQLException getExceptionToRaise( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java index be9f56746..493d2fe60 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java @@ -40,6 +40,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * Monitor the server while the connection is executing methods for more sophisticated failure @@ -269,12 +270,10 @@ public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { final Connection conn = connectFunc.call(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java index e0853e2d9..8e661d23b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java @@ -156,11 +156,7 @@ protected HostMonitor getMonitor( this.serviceContainer.getStorageService(), this.telemetryFactory, this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.pluginService.getProperties(), + this.pluginService.getConnectionContext(), (servicesContainer) -> new HostMonitorImpl( servicesContainer, hostSpec, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java index 0e7472939..698349a25 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java @@ -41,6 +41,7 @@ import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * Monitor the server while the connection is executing methods for more sophisticated failure @@ -227,12 +228,10 @@ public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { final Connection conn = connectFunc.call(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index e58378ddb..38f1333e5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -364,11 +364,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.props + this.pluginService.getConnectionContext() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 5fade66de..81cb525af 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -174,11 +174,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.initialConnectionProps + this.pluginService.getConnectionContext() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index ef2f95550..987f09a91 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -53,6 +53,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -301,9 +302,7 @@ public T execute( @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties properties, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -905,19 +904,16 @@ private boolean canDirectExecute(final String methodName) { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { this.initFailoverMode(); Connection conn = null; try { - conn = - this.staleDnsHelper.getVerifiedConnection(isInitialConnection, this.hostListProviderService, - driverProtocol, hostSpec, props, connectFunc); + conn = this.staleDnsHelper.getVerifiedConnection( + isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); } catch (final SQLException e) { if (!this.enableConnectFailover || !shouldExceptionTriggerConnectionSwitch(e)) { throw e; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java index 76a0ec244..7eb64085b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java @@ -53,6 +53,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -232,9 +233,7 @@ public T execute( @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties properties, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -728,20 +727,17 @@ protected void initFailoverMode() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - + final JdbcCallable connectFunc) throws SQLException { this.initFailoverMode(); Connection conn = null; - + Properties props = connectionContext.getPropsCopy(); if (!ENABLE_CONNECT_FAILOVER.getBoolean(props)) { - return this.staleDnsHelper.getVerifiedConnection(isInitialConnection, this.hostListProviderService, - driverProtocol, hostSpec, props, connectFunc); + return this.staleDnsHelper.getVerifiedConnection( + isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); } final HostSpec hostSpecWithAvailability = this.pluginService.getHosts().stream() @@ -753,8 +749,8 @@ public Connection connect( || hostSpecWithAvailability.getAvailability() != HostAvailability.NOT_AVAILABLE) { try { - conn = this.staleDnsHelper.getVerifiedConnection(isInitialConnection, this.hostListProviderService, - driverProtocol, hostSpec, props, connectFunc); + conn = this.staleDnsHelper.getVerifiedConnection( + isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); } catch (final SQLException e) { if (!this.shouldExceptionTriggerConnectionSwitch(e)) { throw e; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index 6a06e1d68..f40e4be72 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -42,6 +42,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -152,27 +153,26 @@ public FederatedAuthPlugin(final PluginService pluginService, @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - return connectInternal(hostSpec, props, connectFunc); + final JdbcCallable connectFunc) throws SQLException { + return connectInternal(hostSpec, connectionContext.getPropsCopy(), connectFunc); } @Override public Connection forceConnect( - final @NonNull String driverProtocol, + final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec, - final @NonNull Properties props, final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, props, forceConnectFunc); + return connectInternal(hostSpec, connectionContext.getPropsCopy(), forceConnectFunc); } - private Connection connectInternal(final HostSpec hostSpec, final Properties props, + private Connection connectInternal( + final HostSpec hostSpec, + final Properties props, final JdbcCallable connectFunc) throws SQLException { this.samlUtils.checkIdpCredentialsWithFallback(IDP_USERNAME, IDP_PASSWORD, props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index 7e47ff479..b09b0d9e1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -40,6 +40,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -132,16 +133,22 @@ public Set getSubscribedMethods() { } @Override - public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, props, connectFunc); + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { + return connectInternal(hostSpec, connectionContext.getPropsCopy(), connectFunc); } @Override - public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable forceConnectFunc) + public Connection forceConnect( + ConnectionContext connectionContext, + HostSpec hostSpec, + boolean isInitialConnection, + JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, props, forceConnectFunc); + return connectInternal(hostSpec, connectionContext.getPropsCopy(), forceConnectFunc); } private Connection connectInternal(final HostSpec hostSpec, final Properties props, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index 5541ac917..64ac50ba7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -40,6 +40,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -106,17 +107,18 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - return connectInternal(driverProtocol, hostSpec, props, connectFunc); + final JdbcCallable connectFunc) throws SQLException { + return connectInternal(connectionContext, hostSpec, connectFunc); } - private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Properties props, + private Connection connectInternal( + ConnectionContext connectionContext, + HostSpec hostSpec, JdbcCallable connectFunc) throws SQLException { + Properties props = connectionContext.getPropsCopy(); if (StringUtils.isNullOrEmpty(PropertyDefinition.USER.getString(props))) { throw new SQLException(PropertyDefinition.USER.name + " is null or empty."); } @@ -224,13 +226,12 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro @Override public Connection forceConnect( - final @NonNull String driverProtocol, + final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec, - final @NonNull Properties props, final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(driverProtocol, hostSpec, props, forceConnectFunc); + return connectInternal(connectionContext, hostSpec, forceConnectFunc); } public static void clearCache() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java index 2bff9680c..be21a576d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java @@ -35,6 +35,7 @@ import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.connection.ConnectionContext; public class LimitlessConnectionPlugin extends AbstractConnectionPlugin { @@ -103,12 +104,10 @@ public LimitlessConnectionPlugin( @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { Connection conn = null; @@ -131,7 +130,7 @@ public Connection connect( final LimitlessConnectionContext context = new LimitlessConnectionContext( hostSpec, - props, + connectionContext.getPropsCopy(), conn, connectFunc, null, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java index 4582acaf6..db2a27bb6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java @@ -157,12 +157,11 @@ public void establishConnection(final LimitlessConnectionContext context) throws if (this.isLoginException(e)) { throw e; } - if (selectedHostSpec != null) { - LOGGER.fine(Messages.get( - "LimitlessRouterServiceImpl.failedToConnectToHost", - new Object[] {selectedHostSpec.getHost()})); - selectedHostSpec.setAvailability(HostAvailability.NOT_AVAILABLE); - } + + LOGGER.fine(Messages.get( + "LimitlessRouterServiceImpl.failedToConnectToHost", + new Object[] {selectedHostSpec.getHost()})); + selectedHostSpec.setAvailability(HostAvailability.NOT_AVAILABLE); // Retry connect prioritising the healthiest router for best chance of // connection over load-balancing with round-robin. retryConnectWithLeastLoadedRouters(context); @@ -315,10 +314,7 @@ protected boolean isLoginException(Throwable throwable) { } @Override - public void startMonitoring(final @NonNull HostSpec hostSpec, - final @NonNull Properties props, - final int intervalMs) { - + public void startMonitoring(final @NonNull HostSpec hostSpec, final @NonNull Properties props, final int intervalMs) { try { final String limitlessRouterMonitorKey = pluginService.getHostListProvider().getClusterId(); this.servicesContainer.getMonitorService().runIfAbsent( @@ -327,11 +323,7 @@ public void startMonitoring(final @NonNull HostSpec hostSpec, this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - props, + this.pluginService.getConnectionContext(), (servicesContainer) -> new LimitlessRouterMonitor( servicesContainer, hostSpec, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 2d25675e2..e4cdcc760 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -44,6 +44,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionContext; public class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implements CanReleaseResources { @@ -124,25 +125,20 @@ public Set getSubscribedMethods() { @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { - this.hostListProviderService = hostListProviderService; initHostProviderFunc.call(); } @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final @NonNull JdbcCallable connectFunc) - throws SQLException { + final JdbcCallable connectFunc) throws SQLException { if (!pluginService.acceptsStrategy(hostSpec.getRole(), this.readerSelectorStrategy)) { throw new UnsupportedOperationException( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java index 682c3080f..4043edea3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java @@ -23,7 +23,6 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.logging.Logger; import software.amazon.jdbc.HostListProviderService; import software.amazon.jdbc.HostRole; @@ -34,6 +33,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -61,9 +61,8 @@ public AuroraStaleDnsHelper(final PluginService pluginService) { public Connection getVerifiedConnection( final boolean isInitialConnection, final HostListProviderService hostListProviderService, - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final JdbcCallable connectFunc) throws SQLException { if (!this.rdsUtils.isWriterClusterDns(hostSpec.getHost())) { @@ -148,7 +147,7 @@ public Connection getVerifiedConnection( ); } - final Connection writerConn = this.pluginService.connect(this.writerHostSpec, props); + final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectionContext.getPropsCopy()); if (isInitialConnection) { hostListProviderService.setInitialConnectionHostSpec(this.writerHostSpec); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java index a5babfc3c..fb8eca976 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java @@ -32,6 +32,7 @@ import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; +import software.amazon.jdbc.util.connection.ConnectionContext; /** * After Aurora DB cluster fail over is completed and a cluster has elected a new writer node, the corresponding @@ -75,21 +76,17 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - return this.helper.getVerifiedConnection(isInitialConnection, this.hostListProviderService, - driverProtocol, hostSpec, props, connectFunc); + final JdbcCallable connectFunc) throws SQLException { + return this.helper.getVerifiedConnection( + isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); } @Override public void initHostProvider( - final String driverProtocol, - final String initialUrl, - final Properties props, + final ConnectionContext connectionContext, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { this.hostListProviderService = hostListProviderService; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java index f9ae41b6d..e35963955 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java @@ -40,6 +40,7 @@ import software.amazon.jdbc.RandomHostSelector; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.CacheMap; public class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { @@ -111,13 +112,10 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final String driverProtocol, + final ConnectionContext connectionContext, final HostSpec hostSpec, - final Properties props, final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - + final JdbcCallable connectFunc) throws SQLException { Connection conn = connectFunc.call(); if (isInitialConnection) { this.hostResponseTimeService.setHosts(this.pluginService.getHosts()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java index ee157b3ea..6e15bfe0c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java @@ -79,11 +79,7 @@ public void setHosts(final @NonNull List hosts) { servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), servicesContainer.getDefaultConnectionProvider(), - this.pluginService.getOriginalUrl(), - this.pluginService.getDriverProtocol(), - this.pluginService.getTargetDriverDialect(), - this.pluginService.getDialect(), - this.props, + this.pluginService.getConnectionContext(), (servicesContainer) -> new NodeResponseTimeMonitor(pluginService, hostSpec, this.props, this.intervalMs)); } catch (SQLException e) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index eae8e72b2..6a5c4a191 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -17,13 +17,11 @@ package software.amazon.jdbc.util; import java.sql.SQLException; -import java.util.Properties; import java.util.concurrent.locks.ReentrantLock; import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.PartialPluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -60,31 +58,20 @@ public FullServicesContainer createServiceContainer( MonitorService monitorService, ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, - String originalUrl, - String targetDriverProtocol, - TargetDriverDialect driverDialect, - Dialect dbDialect, - Properties props) throws SQLException { + ConnectionContext connectionContext) throws SQLException { FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, connectionProvider, telemetryFactory); ConnectionPluginManager pluginManager = new ConnectionPluginManager( connectionProvider, null, null, telemetryFactory); servicesContainer.setConnectionPluginManager(pluginManager); - PartialPluginService partialPluginService = new PartialPluginService( - servicesContainer, - props, - originalUrl, - targetDriverProtocol, - driverDialect, - dbDialect - ); + PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, connectionContext); servicesContainer.setHostListProviderService(partialPluginService); servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); - pluginManager.init(servicesContainer, props, partialPluginService, null); + pluginManager.init(servicesContainer, connectionContext.getPropsCopy(), partialPluginService, null); return servicesContainer; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java index 20f1c6669..9223a3853 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java @@ -23,7 +23,7 @@ import software.amazon.jdbc.util.PropertyUtils; public class ConnectionContext { - protected final static ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); + protected static final ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); protected final String url; protected final String protocol; protected final TargetDriverDialect driverDialect; @@ -31,8 +31,12 @@ public class ConnectionContext { protected Dialect dbDialect; public ConnectionContext(String url, TargetDriverDialect driverDialect, Properties props) { + this(url, connectionUrlParser.getProtocol(url), driverDialect, props); + } + + public ConnectionContext(String url, String protocol, TargetDriverDialect driverDialect, Properties props) { this.url = url; - this.protocol = connectionUrlParser.getProtocol(url); + this.protocol = protocol; this.driverDialect = driverDialect; this.props = props; } @@ -49,7 +53,7 @@ public TargetDriverDialect getDriverDialect() { return driverDialect; } - public Properties getProps() { + public Properties getPropsCopy() { return PropertyUtils.copyProperties(props); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 9e8356eb9..e9232cbad 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -24,11 +24,8 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PartialPluginService; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.FullServicesContainerImpl; -import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -42,7 +39,7 @@ */ @Deprecated public class ConnectionServiceImpl implements ConnectionService { - protected final String targetDriverProtocol; + protected final ConnectionContext connectionContext; protected final ConnectionPluginManager pluginManager; protected final PluginService pluginService; @@ -57,12 +54,8 @@ public ConnectionServiceImpl( MonitorService monitorService, TelemetryFactory telemetryFactory, ConnectionProvider connectionProvider, - String originalUrl, - String targetDriverProtocol, - TargetDriverDialect driverDialect, - Dialect dbDialect, - Properties props) throws SQLException { - this.targetDriverProtocol = targetDriverProtocol; + ConnectionContext connectionContext) throws SQLException { + this.connectionContext = connectionContext; FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, connectionProvider, telemetryFactory); @@ -72,29 +65,20 @@ public ConnectionServiceImpl( null, telemetryFactory); servicesContainer.setConnectionPluginManager(this.pluginManager); - - Properties propsCopy = PropertyUtils.copyProperties(props); - PartialPluginService partialPluginService = new PartialPluginService( - servicesContainer, - propsCopy, - originalUrl, - this.targetDriverProtocol, - driverDialect, - dbDialect - ); + PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, this.connectionContext); servicesContainer.setHostListProviderService(partialPluginService); servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); this.pluginService = partialPluginService; - this.pluginManager.init(servicesContainer, propsCopy, partialPluginService, null); + this.pluginManager.init(servicesContainer, this.connectionContext.getPropsCopy(), partialPluginService, null); } @Override @Deprecated public Connection open(HostSpec hostSpec, Properties props) throws SQLException { - return this.pluginManager.forceConnect(this.targetDriverProtocol, hostSpec, props, true, null); + return this.pluginManager.forceConnect(this.connectionContext, hostSpec, true, null); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java index 6360d8acd..975fa3f8d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java @@ -17,12 +17,10 @@ package software.amazon.jdbc.util.monitoring; import java.sql.SQLException; -import java.util.Properties; import java.util.Set; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -61,11 +59,7 @@ void registerMonitorTypeIfAbsent( * @param telemetryFactory the telemetry factory for creating telemetry data. * @param defaultConnectionProvider the connection provider to use to create new connections if the monitor * requires it. - * @param originalUrl the URL of the original database connection. - * @param driverProtocol the protocol for the underlying target driver. - * @param driverDialect the target driver dialect. - * @param dbDialect the database dialect. - * @param originalProps the properties of the original database connection. + * @param connectionContext the connection info for the original connection. * @param initializer an initializer function to use to create the monitor if it does not already exist. * @param the type of the monitor. * @return the new or existing monitor. @@ -77,11 +71,7 @@ T runIfAbsent( StorageService storageService, TelemetryFactory telemetryFactory, ConnectionProvider defaultConnectionProvider, - String originalUrl, - String driverProtocol, - TargetDriverDialect driverDialect, - Dialect dbDialect, - Properties originalProps, + ConnectionContext connectionContext, MonitorInitializer initializer) throws SQLException; /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index 5e0df4817..018f66120 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -33,17 +33,15 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.hostlistprovider.monitoring.ClusterTopologyMonitorImpl; import software.amazon.jdbc.hostlistprovider.monitoring.MultiAzClusterTopologyMonitorImpl; import software.amazon.jdbc.plugin.strategy.fastestresponse.NodeResponseTimeMonitor; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.ExecutorFactory; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.ServiceUtility; +import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.events.DataAccessEvent; import software.amazon.jdbc.util.events.Event; import software.amazon.jdbc.util.events.EventPublisher; @@ -182,11 +180,7 @@ public T runIfAbsent( StorageService storageService, TelemetryFactory telemetryFactory, ConnectionProvider defaultConnectionProvider, - String originalUrl, - String driverProtocol, - TargetDriverDialect driverDialect, - Dialect dbDialect, - Properties originalProps, + ConnectionContext connectionContext, MonitorInitializer initializer) throws SQLException { CacheContainer cacheContainer = monitorCaches.get(monitorClass); if (cacheContainer == null) { @@ -208,11 +202,7 @@ public T runIfAbsent( storageService, defaultConnectionProvider, telemetryFactory, - originalUrl, - driverProtocol, - driverDialect, - dbDialect, - originalProps); + connectionContext); final MonitorItem monitorItemInner = new MonitorItem(() -> initializer.createMonitor(servicesContainer)); monitorItemInner.getMonitor().start(); return monitorItemInner; @@ -239,22 +229,13 @@ protected FullServicesContainer getNewServicesContainer( StorageService storageService, ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, - String originalUrl, - String driverProtocol, - TargetDriverDialect driverDialect, - Dialect dbDialect, - Properties originalProps) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(originalProps); + ConnectionContext connectionContext) throws SQLException { return ServiceUtility.getInstance().createServiceContainer( storageService, this, connectionProvider, telemetryFactory, - originalUrl, - driverProtocol, - driverDialect, - dbDialect, - propsCopy + connectionContext ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index 714e2700a..e01fa6ea6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -51,7 +51,6 @@ import software.amazon.jdbc.dialect.HostListProviderSupplier; import software.amazon.jdbc.profile.ConfigurationProfile; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.ConnectionUrlParser; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.FullServicesContainerImpl; import software.amazon.jdbc.util.Messages; @@ -167,14 +166,8 @@ protected void init(final Properties props, final FullServicesContainer services this.pluginManager.initHostProvider(this.connectionContext, this.hostListProviderService); this.pluginService.refreshHostList(); if (this.pluginService.getCurrentConnection() == null) { - final Connection conn = - this.pluginManager.connect( - this.connectionContext.getProtocol(), - this.pluginService.getInitialConnectionHostSpec(), - props, - true, - null); - + final Connection conn = this.pluginManager.connect( + this.connectionContext, this.pluginService.getInitialConnectionHostSpec(), true, null); if (conn == null) { throw new SQLException(Messages.get("ConnectionWrapper.connectionNotOpen"), SqlState.UNKNOWN_STATE.getState()); } diff --git a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java index c35f6b0f8..f7216d85d 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java +++ b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java @@ -1,33 +1,33 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package integration.container.aurora; - -import java.util.Properties; -import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; -import software.amazon.jdbc.util.FullServicesContainer; - -public class TestAuroraHostListProvider extends AuroraHostListProvider { - - public TestAuroraHostListProvider( - FullServicesContainer servicesContainer, Properties properties, String originalUrl) { - super(properties, originalUrl, servicesContainer, "", "", ""); - } - - public static void clearCache() { - AuroraHostListProvider.clearAll(); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package integration.container.aurora; +// +// import java.util.Properties; +// import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; +// import software.amazon.jdbc.util.FullServicesContainer; +// +// public class TestAuroraHostListProvider extends AuroraHostListProvider { +// +// public TestAuroraHostListProvider( +// FullServicesContainer servicesContainer, Properties properties, String originalUrl) { +// super(properties, originalUrl, servicesContainer, "", "", ""); +// } +// +// public static void clearCache() { +// AuroraHostListProvider.clearAll(); +// } +// } diff --git a/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java b/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java index b356eb391..1d597570a 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java +++ b/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java @@ -17,23 +17,17 @@ package integration.container.aurora; import java.sql.SQLException; -import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.PluginServiceImpl; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; public class TestPluginServiceImpl extends PluginServiceImpl { public TestPluginServiceImpl( - @NonNull FullServicesContainer servicesContainer, - @NonNull Properties props, - @NonNull String originalUrl, - String targetDriverProtocol, - @NonNull final TargetDriverDialect targetDriverDialect) + @NonNull FullServicesContainer servicesContainer, @NonNull ConnectionContext connectionContext) throws SQLException { - - super(servicesContainer, props, originalUrl, targetDriverProtocol, targetDriverDialect); + super(servicesContainer, connectionContext); } public static void clearHostAvailabilityCache() { diff --git a/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java b/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java index ae966c1d1..3fbbbb762 100644 --- a/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java +++ b/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java @@ -1,767 +1,767 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package integration.container.tests; - -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import static software.amazon.jdbc.PropertyDefinition.CONNECT_TIMEOUT; -import static software.amazon.jdbc.PropertyDefinition.PLUGINS; -import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_COUNT; -import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_INTERVAL; -import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_TIME; -import static software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin.FAILOVER_TIMEOUT_MS; - -import integration.TestEnvironmentFeatures; -import integration.container.ConnectionStringHelper; -import integration.container.TestDriverProvider; -import integration.container.TestEnvironment; -import integration.container.aurora.TestAuroraHostListProvider; -import integration.container.aurora.TestPluginServiceImpl; -import integration.container.condition.DisableOnTestFeature; -import integration.container.condition.EnableOnTestFeature; -import integration.util.AuroraTestUtility; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.net.InetAddress; -import java.net.UnknownHostException; -import java.sql.Connection; -import java.sql.DriverManager; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; -import java.util.logging.Logger; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.poi.ss.usermodel.Cell; -import org.apache.poi.ss.usermodel.Row; -import org.apache.poi.xssf.usermodel.XSSFSheet; -import org.apache.poi.xssf.usermodel.XSSFWorkbook; -import org.junit.jupiter.api.MethodOrderer; -import org.junit.jupiter.api.Order; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.TestMethodOrder; -import org.junit.jupiter.api.TestTemplate; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.provider.Arguments; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; -import software.amazon.jdbc.plugin.efm2.HostMonitorServiceImpl; -import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; -import software.amazon.jdbc.util.StringUtils; - -@TestMethodOrder(MethodOrderer.MethodName.class) -@ExtendWith(TestDriverProvider.class) -@EnableOnTestFeature({ - TestEnvironmentFeatures.PERFORMANCE, - TestEnvironmentFeatures.FAILOVER_SUPPORTED -}) -@DisableOnTestFeature(TestEnvironmentFeatures.RUN_DB_METRICS_ONLY) -@Tag("advanced") -@Order(1) -public class AdvancedPerformanceTest { - - private static final Logger LOGGER = Logger.getLogger(AdvancedPerformanceTest.class.getName()); - - private static final String MONITORING_CONNECTION_PREFIX = "monitoring-"; - - private static final int REPEAT_TIMES = - StringUtils.isNullOrEmpty(System.getenv("REPEAT_TIMES")) - ? 5 - : Integer.parseInt(System.getenv("REPEAT_TIMES")); - - private static final int TIMEOUT_SEC = 5; - private static final int CONNECT_TIMEOUT_SEC = 5; - private static final int EFM_FAILOVER_TIMEOUT_MS = 300000; - private static final int EFM_FAILURE_DETECTION_TIME_MS = 30000; - private static final int EFM_FAILURE_DETECTION_INTERVAL_MS = 5000; - private static final int EFM_FAILURE_DETECTION_COUNT = 3; - private static final String QUERY = "SELECT pg_sleep(600)"; // 600s -> 10min - - private static final ConcurrentLinkedQueue perfDataList = new ConcurrentLinkedQueue<>(); - - protected static final AuroraTestUtility auroraUtil = AuroraTestUtility.getUtility(); - - private static void doWritePerfDataToFile( - String fileName, ConcurrentLinkedQueue dataList) throws IOException { - - if (dataList.isEmpty()) { - return; - } - - LOGGER.finest(() -> "File name: " + fileName); - - List sortedData = - dataList.stream() - .sorted( - (d1, d2) -> - d1.paramFailoverDelayMillis == d2.paramFailoverDelayMillis - ? d1.paramDriverName.compareTo(d2.paramDriverName) - : 0) - .collect(Collectors.toList()); - - try (XSSFWorkbook workbook = new XSSFWorkbook()) { - - final XSSFSheet sheet = workbook.createSheet("PerformanceResults"); - - for (int rows = 0; rows < dataList.size(); rows++) { - PerfStat perfStat = sortedData.get(rows); - Row row; - - if (rows == 0) { - // Header - row = sheet.createRow(0); - perfStat.writeHeader(row); - } - - row = sheet.createRow(rows + 1); - perfStat.writeData(row); - } - - // Write to file - final File newExcelFile = new File(fileName); - newExcelFile.createNewFile(); - try (FileOutputStream fileOut = new FileOutputStream(newExcelFile)) { - workbook.write(fileOut); - } - } - } - - @TestTemplate - public void test_AdvancedPerformance() throws IOException { - - perfDataList.clear(); - - try { - Stream argsStream = generateParams(); - argsStream.forEach( - a -> { - try { - ensureClusterHealthy(); - LOGGER.finest("DB cluster is healthy."); - ensureDnsHealthy(); - LOGGER.finest("DNS is healthy."); - - Object[] args = a.get(); - int failoverDelayTimeMillis = (int) args[0]; - int runNumber = (int) args[1]; - - LOGGER.finest( - "Iteration " - + runNumber - + "/" - + REPEAT_TIMES - + " for " - + failoverDelayTimeMillis - + "ms delay"); - - doMeasurePerformance(failoverDelayTimeMillis); - - } catch (InterruptedException ex) { - throw new RuntimeException(ex); - } catch (UnknownHostException e) { - throw new RuntimeException(e); - } - }); - - } finally { - doWritePerfDataToFile( - String.format( - "./build/reports/tests/AdvancedPerformanceResults_" - + "Db_%s_Driver_%s_Instances_%d.xlsx", - TestEnvironment.getCurrent().getInfo().getRequest().getDatabaseEngine(), - TestEnvironment.getCurrent().getCurrentDriver(), - TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances()), - perfDataList); - perfDataList.clear(); - } - } - - private void doMeasurePerformance(int sleepDelayMillis) throws InterruptedException { - - final AtomicLong downtimeNano = new AtomicLong(); - final CountDownLatch startLatch = new CountDownLatch(5); - final CountDownLatch finishLatch = new CountDownLatch(5); - - downtimeNano.set(0); - - final Thread failoverThread = - getThread_Failover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); - final Thread pgThread = - getThread_DirectDriver(sleepDelayMillis, downtimeNano, startLatch, finishLatch); - final Thread wrapperEfmThread = - getThread_WrapperEfm(sleepDelayMillis, downtimeNano, startLatch, finishLatch); - final Thread wrapperEfmFailoverThread = - getThread_WrapperEfmFailover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); - final Thread dnsThread = getThread_DNS(sleepDelayMillis, downtimeNano, startLatch, finishLatch); - - failoverThread.start(); - pgThread.start(); - wrapperEfmThread.start(); - wrapperEfmFailoverThread.start(); - dnsThread.start(); - - LOGGER.finest("All threads started."); - - finishLatch.await(5, TimeUnit.MINUTES); // wait for all threads to complete - - LOGGER.finest("Test is over."); - - assertTrue(downtimeNano.get() > 0); - - failoverThread.interrupt(); - pgThread.interrupt(); - wrapperEfmThread.interrupt(); - wrapperEfmFailoverThread.interrupt(); - dnsThread.interrupt(); - } - - private void ensureDnsHealthy() throws UnknownHostException, InterruptedException { - LOGGER.finest( - "Writer is " - + TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getInstances() - .get(0) - .getInstanceId()); - final String writerIpAddress = - InetAddress.getByName( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getInstances() - .get(0) - .getHost()) - .getHostAddress(); - LOGGER.finest("Writer resolves to " + writerIpAddress); - LOGGER.finest( - "Cluster Endpoint is " - + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()); - String clusterIpAddress = - InetAddress.getByName( - TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) - .getHostAddress(); - LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); - - long startTimeNano = System.nanoTime(); - while (!clusterIpAddress.equals(writerIpAddress) - && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { - Thread.sleep(1000); - clusterIpAddress = - InetAddress.getByName( - TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) - .getHostAddress(); - LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); - } - - if (!clusterIpAddress.equals(writerIpAddress)) { - fail("DNS has stale data"); - } - } - - private Thread getThread_Failover( - final int sleepDelayMillis, - final AtomicLong downtimeNano, - final CountDownLatch startLatch, - final CountDownLatch finishLatch) { - - return new Thread( - () -> { - try { - Thread.sleep(1000); - startLatch.countDown(); // notify that this thread is ready for work - startLatch.await( - 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test - - LOGGER.finest("Waiting " + sleepDelayMillis + "ms..."); - Thread.sleep(sleepDelayMillis); - LOGGER.finest("Trigger failover..."); - - // trigger failover - failoverCluster(); - downtimeNano.set(System.nanoTime()); - LOGGER.finest("Failover is started."); - - } catch (InterruptedException interruptedException) { - // Ignore, stop the thread - } catch (Exception exception) { - fail("Failover thread exception: " + exception); - } finally { - finishLatch.countDown(); - LOGGER.finest("Failover thread is completed."); - } - }); - } - - private Thread getThread_DirectDriver( - final int sleepDelayMillis, - final AtomicLong downtimeNano, - final CountDownLatch startLatch, - final CountDownLatch finishLatch) { - - return new Thread( - () -> { - long failureTimeNano = 0; - try { - // DB_CONN_STR_PREFIX - final Properties props = ConnectionStringHelper.getDefaultProperties(); - final Connection conn = - openConnectionWithRetry( - ConnectionStringHelper.getUrl( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpoint(), - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpointPort(), - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getDefaultDbName()), - props); - LOGGER.finest("DirectDriver connection is open."); - - Thread.sleep(1000); - startLatch.countDown(); // notify that this thread is ready for work - startLatch.await( - 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test - - LOGGER.finest("DirectDriver Starting long query..."); - // Execute long query - final Statement statement = conn.createStatement(); - try (final ResultSet result = statement.executeQuery(QUERY)) { - fail("Sleep query finished, should not be possible with network downed."); - } catch (SQLException throwable) { // Catching executing query - LOGGER.finest("DirectDriver thread exception: " + throwable); - // Calculate and add detection time - assertTrue(downtimeNano.get() > 0); - failureTimeNano = System.nanoTime() - downtimeNano.get(); - } - - } catch (InterruptedException interruptedException) { - // Ignore, stop the thread - } catch (Exception exception) { - fail("PG thread exception: " + exception); - } finally { - PerfStat data = new PerfStat(); - data.paramFailoverDelayMillis = sleepDelayMillis; - data.paramDriverName = - "DirectDriver - " + TestEnvironment.getCurrent().getCurrentDriver(); - data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); - LOGGER.finest("DirectDriver Collected data: " + data); - perfDataList.add(data); - LOGGER.finest( - "DirectDriver Failure detection time is " + data.failureDetectionTimeMillis + "ms"); - - finishLatch.countDown(); - LOGGER.finest("DirectDriver thread is completed."); - } - }); - } - - private Thread getThread_WrapperEfm( - final int sleepDelayMillis, - final AtomicLong downtimeNano, - final CountDownLatch startLatch, - final CountDownLatch finishLatch) { - - return new Thread( - () -> { - long failureTimeNano = 0; - try { - final Properties props = ConnectionStringHelper.getDefaultProperties(); - - props.setProperty( - MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, - String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); - props.setProperty( - MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, - String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); - CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); - - FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); - FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_INTERVAL_MS)); - FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); - PLUGINS.set(props, "efm"); - - final Connection conn = - openConnectionWithRetry( - ConnectionStringHelper.getWrapperUrl( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpoint(), - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpointPort(), - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getDefaultDbName()), - props); - LOGGER.finest("WrapperEfm connection is open."); - - Thread.sleep(1000); - startLatch.countDown(); // notify that this thread is ready for work - startLatch.await( - 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test - - LOGGER.finest("WrapperEfm Starting long query..."); - // Execute long query - final Statement statement = conn.createStatement(); - try (final ResultSet result = statement.executeQuery(QUERY)) { - fail("Sleep query finished, should not be possible with network downed."); - } catch (SQLException throwable) { // Catching executing query - LOGGER.finest("WrapperEfm thread exception: " + throwable); - - // Calculate and add detection time - assertTrue(downtimeNano.get() > 0); - failureTimeNano = System.nanoTime() - downtimeNano.get(); - } - - } catch (InterruptedException interruptedException) { - // Ignore, stop the thread - } catch (Exception exception) { - fail("WrapperEfm thread exception: " + exception); - } finally { - PerfStat data = new PerfStat(); - data.paramFailoverDelayMillis = sleepDelayMillis; - data.paramDriverName = - String.format( - "AWS Wrapper (%s, EFM)", TestEnvironment.getCurrent().getCurrentDriver()); - data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); - LOGGER.finest("WrapperEfm Collected data: " + data); - perfDataList.add(data); - LOGGER.finest( - "WrapperEfm Failure detection time is " + data.failureDetectionTimeMillis + "ms"); - - finishLatch.countDown(); - LOGGER.finest("WrapperEfm thread is completed."); - } - }); - } - - private Thread getThread_WrapperEfmFailover( - final int sleepDelayMillis, - final AtomicLong downtimeNano, - final CountDownLatch startLatch, - final CountDownLatch finishLatch) { - - return new Thread( - () -> { - long failureTimeNano = 0; - try { - final Properties props = ConnectionStringHelper.getDefaultProperties(); - - props.setProperty( - MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, - String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); - props.setProperty( - MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, - String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); - CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); - - FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); - FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); - FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); - FAILOVER_TIMEOUT_MS.set(props, Integer.toString(EFM_FAILOVER_TIMEOUT_MS)); - PLUGINS.set(props, "failover,efm"); - - final Connection conn = - openConnectionWithRetry( - ConnectionStringHelper.getWrapperUrl( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpoint(), - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpointPort(), - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getDefaultDbName()), - props); - LOGGER.finest("WrapperEfmFailover connection is open."); - - Thread.sleep(1000); - startLatch.countDown(); // notify that this thread is ready for work - startLatch.await( - 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test - - LOGGER.finest("WrapperEfmFailover Starting long query..."); - // Execute long query - final Statement statement = conn.createStatement(); - try (final ResultSet result = statement.executeQuery(QUERY)) { - fail("Sleep query finished, should not be possible with network downed."); - } catch (SQLException throwable) { - LOGGER.finest("WrapperEfmFailover thread exception: " + throwable); - if (throwable instanceof FailoverSuccessSQLException) { - // Calculate and add detection time - assertTrue(downtimeNano.get() > 0); - failureTimeNano = System.nanoTime() - downtimeNano.get(); - } - } - - } catch (InterruptedException interruptedException) { - // Ignore, stop the thread - } catch (Exception exception) { - fail("WrapperEfmFailover thread exception: " + exception); - } finally { - PerfStat data = new PerfStat(); - data.paramFailoverDelayMillis = sleepDelayMillis; - data.paramDriverName = - String.format( - "AWS Wrapper (%s, EFM, Failover)", - TestEnvironment.getCurrent().getCurrentDriver()); - data.reconnectTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); - LOGGER.finest("WrapperEfmFailover Collected data: " + data); - perfDataList.add(data); - LOGGER.finest( - "WrapperEfmFailover Reconnect time is " + data.reconnectTimeMillis + "ms"); - - finishLatch.countDown(); - LOGGER.finest("WrapperEfmFailover thread is completed."); - } - }); - } - - private Thread getThread_DNS( - final int sleepDelayMillis, - final AtomicLong downtimeNano, - final CountDownLatch startLatch, - final CountDownLatch finishLatch) { - - return new Thread( - () -> { - long failureTimeNano = 0; - String currentClusterIpAddress; - - try { - currentClusterIpAddress = - InetAddress.getByName( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpoint()) - .getHostAddress(); - LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); - - Thread.sleep(1000); - startLatch.countDown(); // notify that this thread is ready for work - startLatch.await( - 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test - - String clusterIpAddress = - InetAddress.getByName( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpoint()) - .getHostAddress(); - - long startTimeNano = System.nanoTime(); - while (clusterIpAddress.equals(currentClusterIpAddress) - && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { - Thread.sleep(1000); - clusterIpAddress = - InetAddress.getByName( - TestEnvironment.getCurrent() - .getInfo() - .getDatabaseInfo() - .getClusterEndpoint()) - .getHostAddress(); - LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); - } - - // DNS data has changed - if (!clusterIpAddress.equals(currentClusterIpAddress)) { - assertTrue(downtimeNano.get() > 0); - failureTimeNano = System.nanoTime() - downtimeNano.get(); - } - - } catch (InterruptedException interruptedException) { - // Ignore, stop the thread - } catch (Exception exception) { - fail("Failover thread exception: " + exception); - } finally { - PerfStat data = new PerfStat(); - data.paramFailoverDelayMillis = sleepDelayMillis; - data.paramDriverName = "DNS"; - data.dnsUpdateTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); - LOGGER.finest("DNS Collected data: " + data); - perfDataList.add(data); - LOGGER.finest("DNS Update time is " + data.dnsUpdateTimeMillis + "ms"); - - finishLatch.countDown(); - LOGGER.finest("DNS thread is completed."); - } - }); - } - - private Connection openConnectionWithRetry(String url, Properties props) { - Connection conn = null; - int connectCount = 0; - while (conn == null && connectCount < 10) { - try { - conn = DriverManager.getConnection(url, props); - - } catch (SQLException sqlEx) { - // ignore, try to connect again - } - connectCount++; - } - - if (conn == null) { - fail("Can't connect to " + url); - } - return conn; - } - - private void failoverCluster() throws InterruptedException { - String clusterId = TestEnvironment.getCurrent().getInfo().getRdsDbName(); - String randomNode = auroraUtil.getRandomDBClusterReaderInstanceId(clusterId); - auroraUtil.failoverClusterToTarget(clusterId, randomNode); - } - - private void ensureClusterHealthy() throws InterruptedException { - - auroraUtil.waitUntilClusterHasRightState( - TestEnvironment.getCurrent().getInfo().getRdsDbName()); - - // Always get the latest topology info with writer as first - List latestTopology = new ArrayList<>(); - - // Need to ensure that cluster details through API matches topology fetched through SQL - // Wait up to 5min - long startTimeNano = System.nanoTime(); - while ((latestTopology.size() - != TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances() - || !auroraUtil.isDBInstanceWriter(latestTopology.get(0))) - && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { - - Thread.sleep(5000); - - try { - latestTopology = auroraUtil.getAuroraInstanceIds(); - } catch (SQLException ex) { - latestTopology = new ArrayList<>(); - } - } - assertTrue( - auroraUtil.isDBInstanceWriter( - TestEnvironment.getCurrent().getInfo().getRdsDbName(), latestTopology.get(0))); - String currentWriter = latestTopology.get(0); - - // Adjust database info to reflect a current writer and to move corresponding instance to - // position 0. - TestEnvironment.getCurrent().getInfo().getDatabaseInfo().moveInstanceFirst(currentWriter); - TestEnvironment.getCurrent().getInfo().getProxyDatabaseInfo().moveInstanceFirst(currentWriter); - - auroraUtil.makeSureInstancesUp(TimeUnit.MINUTES.toSeconds(5)); - - TestAuroraHostListProvider.clearCache(); - TestPluginServiceImpl.clearHostAvailabilityCache(); - HostMonitorThreadContainer.releaseInstance(); - HostMonitorServiceImpl.closeAllMonitors(); - } - - private static Stream generateParams() { - - ArrayList args = new ArrayList<>(); - - for (int i = 1; i <= REPEAT_TIMES; i++) { - args.add(Arguments.of(10000, i)); - } - for (int i = 1; i <= REPEAT_TIMES; i++) { - args.add(Arguments.of(20000, i)); - } - for (int i = 1; i <= REPEAT_TIMES; i++) { - args.add(Arguments.of(30000, i)); - } - for (int i = 1; i <= REPEAT_TIMES; i++) { - args.add(Arguments.of(40000, i)); - } - for (int i = 1; i <= REPEAT_TIMES; i++) { - args.add(Arguments.of(50000, i)); - } - for (int i = 1; i <= REPEAT_TIMES; i++) { - args.add(Arguments.of(60000, i)); - } - - return Stream.of(args.toArray(new Arguments[0])); - } - - private static class PerfStat { - - public String paramDriverName; - public int paramFailoverDelayMillis; - public long failureDetectionTimeMillis; - public long reconnectTimeMillis; - public long dnsUpdateTimeMillis; - - public void writeHeader(Row row) { - Cell cell = row.createCell(0); - cell.setCellValue("Driver Configuration"); - cell = row.createCell(1); - cell.setCellValue("Failover Delay Millis"); - cell = row.createCell(2); - cell.setCellValue("Failure Detection Time Millis"); - cell = row.createCell(3); - cell.setCellValue("Reconnect Time Millis"); - cell = row.createCell(4); - cell.setCellValue("DNS Update Time Millis"); - } - - public void writeData(Row row) { - Cell cell = row.createCell(0); - cell.setCellValue(this.paramDriverName); - cell = row.createCell(1); - cell.setCellValue(this.paramFailoverDelayMillis); - cell = row.createCell(2); - cell.setCellValue(this.failureDetectionTimeMillis); - cell = row.createCell(3); - cell.setCellValue(this.reconnectTimeMillis); - cell = row.createCell(4); - cell.setCellValue(this.dnsUpdateTimeMillis); - } - - @Override - public String toString() { - return String.format("%s [\nparamDriverName=%s,\nparamFailoverDelayMillis=%d,\n" - + "failureDetectionTimeMillis=%d,\nreconnectTimeMillis=%d,\ndnsUpdateTimeMillis=%d ]", - super.toString(), - this.paramDriverName, - this.paramFailoverDelayMillis, - this.failureDetectionTimeMillis, - this.reconnectTimeMillis, - this.dnsUpdateTimeMillis); - } - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package integration.container.tests; +// +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.junit.jupiter.api.Assertions.fail; +// import static software.amazon.jdbc.PropertyDefinition.CONNECT_TIMEOUT; +// import static software.amazon.jdbc.PropertyDefinition.PLUGINS; +// import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_COUNT; +// import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_INTERVAL; +// import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_TIME; +// import static software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin.FAILOVER_TIMEOUT_MS; +// +// import integration.TestEnvironmentFeatures; +// import integration.container.ConnectionStringHelper; +// import integration.container.TestDriverProvider; +// import integration.container.TestEnvironment; +// import integration.container.aurora.TestAuroraHostListProvider; +// import integration.container.aurora.TestPluginServiceImpl; +// import integration.container.condition.DisableOnTestFeature; +// import integration.container.condition.EnableOnTestFeature; +// import integration.util.AuroraTestUtility; +// import java.io.File; +// import java.io.FileOutputStream; +// import java.io.IOException; +// import java.net.InetAddress; +// import java.net.UnknownHostException; +// import java.sql.Connection; +// import java.sql.DriverManager; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.ArrayList; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.ConcurrentLinkedQueue; +// import java.util.concurrent.CountDownLatch; +// import java.util.concurrent.TimeUnit; +// import java.util.concurrent.atomic.AtomicLong; +// import java.util.logging.Logger; +// import java.util.stream.Collectors; +// import java.util.stream.Stream; +// import org.apache.poi.ss.usermodel.Cell; +// import org.apache.poi.ss.usermodel.Row; +// import org.apache.poi.xssf.usermodel.XSSFSheet; +// import org.apache.poi.xssf.usermodel.XSSFWorkbook; +// import org.junit.jupiter.api.MethodOrderer; +// import org.junit.jupiter.api.Order; +// import org.junit.jupiter.api.Tag; +// import org.junit.jupiter.api.TestMethodOrder; +// import org.junit.jupiter.api.TestTemplate; +// import org.junit.jupiter.api.extension.ExtendWith; +// import org.junit.jupiter.params.provider.Arguments; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; +// import software.amazon.jdbc.plugin.efm2.HostMonitorServiceImpl; +// import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; +// import software.amazon.jdbc.util.StringUtils; +// +// @TestMethodOrder(MethodOrderer.MethodName.class) +// @ExtendWith(TestDriverProvider.class) +// @EnableOnTestFeature({ +// TestEnvironmentFeatures.PERFORMANCE, +// TestEnvironmentFeatures.FAILOVER_SUPPORTED +// }) +// @DisableOnTestFeature(TestEnvironmentFeatures.RUN_DB_METRICS_ONLY) +// @Tag("advanced") +// @Order(1) +// public class AdvancedPerformanceTest { +// +// private static final Logger LOGGER = Logger.getLogger(AdvancedPerformanceTest.class.getName()); +// +// private static final String MONITORING_CONNECTION_PREFIX = "monitoring-"; +// +// private static final int REPEAT_TIMES = +// StringUtils.isNullOrEmpty(System.getenv("REPEAT_TIMES")) +// ? 5 +// : Integer.parseInt(System.getenv("REPEAT_TIMES")); +// +// private static final int TIMEOUT_SEC = 5; +// private static final int CONNECT_TIMEOUT_SEC = 5; +// private static final int EFM_FAILOVER_TIMEOUT_MS = 300000; +// private static final int EFM_FAILURE_DETECTION_TIME_MS = 30000; +// private static final int EFM_FAILURE_DETECTION_INTERVAL_MS = 5000; +// private static final int EFM_FAILURE_DETECTION_COUNT = 3; +// private static final String QUERY = "SELECT pg_sleep(600)"; // 600s -> 10min +// +// private static final ConcurrentLinkedQueue perfDataList = new ConcurrentLinkedQueue<>(); +// +// protected static final AuroraTestUtility auroraUtil = AuroraTestUtility.getUtility(); +// +// private static void doWritePerfDataToFile( +// String fileName, ConcurrentLinkedQueue dataList) throws IOException { +// +// if (dataList.isEmpty()) { +// return; +// } +// +// LOGGER.finest(() -> "File name: " + fileName); +// +// List sortedData = +// dataList.stream() +// .sorted( +// (d1, d2) -> +// d1.paramFailoverDelayMillis == d2.paramFailoverDelayMillis +// ? d1.paramDriverName.compareTo(d2.paramDriverName) +// : 0) +// .collect(Collectors.toList()); +// +// try (XSSFWorkbook workbook = new XSSFWorkbook()) { +// +// final XSSFSheet sheet = workbook.createSheet("PerformanceResults"); +// +// for (int rows = 0; rows < dataList.size(); rows++) { +// PerfStat perfStat = sortedData.get(rows); +// Row row; +// +// if (rows == 0) { +// // Header +// row = sheet.createRow(0); +// perfStat.writeHeader(row); +// } +// +// row = sheet.createRow(rows + 1); +// perfStat.writeData(row); +// } +// +// // Write to file +// final File newExcelFile = new File(fileName); +// newExcelFile.createNewFile(); +// try (FileOutputStream fileOut = new FileOutputStream(newExcelFile)) { +// workbook.write(fileOut); +// } +// } +// } +// +// @TestTemplate +// public void test_AdvancedPerformance() throws IOException { +// +// perfDataList.clear(); +// +// try { +// Stream argsStream = generateParams(); +// argsStream.forEach( +// a -> { +// try { +// ensureClusterHealthy(); +// LOGGER.finest("DB cluster is healthy."); +// ensureDnsHealthy(); +// LOGGER.finest("DNS is healthy."); +// +// Object[] args = a.get(); +// int failoverDelayTimeMillis = (int) args[0]; +// int runNumber = (int) args[1]; +// +// LOGGER.finest( +// "Iteration " +// + runNumber +// + "/" +// + REPEAT_TIMES +// + " for " +// + failoverDelayTimeMillis +// + "ms delay"); +// +// doMeasurePerformance(failoverDelayTimeMillis); +// +// } catch (InterruptedException ex) { +// throw new RuntimeException(ex); +// } catch (UnknownHostException e) { +// throw new RuntimeException(e); +// } +// }); +// +// } finally { +// doWritePerfDataToFile( +// String.format( +// "./build/reports/tests/AdvancedPerformanceResults_" +// + "Db_%s_Driver_%s_Instances_%d.xlsx", +// TestEnvironment.getCurrent().getInfo().getRequest().getDatabaseEngine(), +// TestEnvironment.getCurrent().getCurrentDriver(), +// TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances()), +// perfDataList); +// perfDataList.clear(); +// } +// } +// +// private void doMeasurePerformance(int sleepDelayMillis) throws InterruptedException { +// +// final AtomicLong downtimeNano = new AtomicLong(); +// final CountDownLatch startLatch = new CountDownLatch(5); +// final CountDownLatch finishLatch = new CountDownLatch(5); +// +// downtimeNano.set(0); +// +// final Thread failoverThread = +// getThread_Failover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); +// final Thread pgThread = +// getThread_DirectDriver(sleepDelayMillis, downtimeNano, startLatch, finishLatch); +// final Thread wrapperEfmThread = +// getThread_WrapperEfm(sleepDelayMillis, downtimeNano, startLatch, finishLatch); +// final Thread wrapperEfmFailoverThread = +// getThread_WrapperEfmFailover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); +// final Thread dnsThread = getThread_DNS(sleepDelayMillis, downtimeNano, startLatch, finishLatch); +// +// failoverThread.start(); +// pgThread.start(); +// wrapperEfmThread.start(); +// wrapperEfmFailoverThread.start(); +// dnsThread.start(); +// +// LOGGER.finest("All threads started."); +// +// finishLatch.await(5, TimeUnit.MINUTES); // wait for all threads to complete +// +// LOGGER.finest("Test is over."); +// +// assertTrue(downtimeNano.get() > 0); +// +// failoverThread.interrupt(); +// pgThread.interrupt(); +// wrapperEfmThread.interrupt(); +// wrapperEfmFailoverThread.interrupt(); +// dnsThread.interrupt(); +// } +// +// private void ensureDnsHealthy() throws UnknownHostException, InterruptedException { +// LOGGER.finest( +// "Writer is " +// + TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getInstances() +// .get(0) +// .getInstanceId()); +// final String writerIpAddress = +// InetAddress.getByName( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getInstances() +// .get(0) +// .getHost()) +// .getHostAddress(); +// LOGGER.finest("Writer resolves to " + writerIpAddress); +// LOGGER.finest( +// "Cluster Endpoint is " +// + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()); +// String clusterIpAddress = +// InetAddress.getByName( +// TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) +// .getHostAddress(); +// LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); +// +// long startTimeNano = System.nanoTime(); +// while (!clusterIpAddress.equals(writerIpAddress) +// && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { +// Thread.sleep(1000); +// clusterIpAddress = +// InetAddress.getByName( +// TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) +// .getHostAddress(); +// LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); +// } +// +// if (!clusterIpAddress.equals(writerIpAddress)) { +// fail("DNS has stale data"); +// } +// } +// +// private Thread getThread_Failover( +// final int sleepDelayMillis, +// final AtomicLong downtimeNano, +// final CountDownLatch startLatch, +// final CountDownLatch finishLatch) { +// +// return new Thread( +// () -> { +// try { +// Thread.sleep(1000); +// startLatch.countDown(); // notify that this thread is ready for work +// startLatch.await( +// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test +// +// LOGGER.finest("Waiting " + sleepDelayMillis + "ms..."); +// Thread.sleep(sleepDelayMillis); +// LOGGER.finest("Trigger failover..."); +// +// // trigger failover +// failoverCluster(); +// downtimeNano.set(System.nanoTime()); +// LOGGER.finest("Failover is started."); +// +// } catch (InterruptedException interruptedException) { +// // Ignore, stop the thread +// } catch (Exception exception) { +// fail("Failover thread exception: " + exception); +// } finally { +// finishLatch.countDown(); +// LOGGER.finest("Failover thread is completed."); +// } +// }); +// } +// +// private Thread getThread_DirectDriver( +// final int sleepDelayMillis, +// final AtomicLong downtimeNano, +// final CountDownLatch startLatch, +// final CountDownLatch finishLatch) { +// +// return new Thread( +// () -> { +// long failureTimeNano = 0; +// try { +// // DB_CONN_STR_PREFIX +// final Properties props = ConnectionStringHelper.getDefaultProperties(); +// final Connection conn = +// openConnectionWithRetry( +// ConnectionStringHelper.getUrl( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpoint(), +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpointPort(), +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getDefaultDbName()), +// props); +// LOGGER.finest("DirectDriver connection is open."); +// +// Thread.sleep(1000); +// startLatch.countDown(); // notify that this thread is ready for work +// startLatch.await( +// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test +// +// LOGGER.finest("DirectDriver Starting long query..."); +// // Execute long query +// final Statement statement = conn.createStatement(); +// try (final ResultSet result = statement.executeQuery(QUERY)) { +// fail("Sleep query finished, should not be possible with network downed."); +// } catch (SQLException throwable) { // Catching executing query +// LOGGER.finest("DirectDriver thread exception: " + throwable); +// // Calculate and add detection time +// assertTrue(downtimeNano.get() > 0); +// failureTimeNano = System.nanoTime() - downtimeNano.get(); +// } +// +// } catch (InterruptedException interruptedException) { +// // Ignore, stop the thread +// } catch (Exception exception) { +// fail("PG thread exception: " + exception); +// } finally { +// PerfStat data = new PerfStat(); +// data.paramFailoverDelayMillis = sleepDelayMillis; +// data.paramDriverName = +// "DirectDriver - " + TestEnvironment.getCurrent().getCurrentDriver(); +// data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); +// LOGGER.finest("DirectDriver Collected data: " + data); +// perfDataList.add(data); +// LOGGER.finest( +// "DirectDriver Failure detection time is " + data.failureDetectionTimeMillis + "ms"); +// +// finishLatch.countDown(); +// LOGGER.finest("DirectDriver thread is completed."); +// } +// }); +// } +// +// private Thread getThread_WrapperEfm( +// final int sleepDelayMillis, +// final AtomicLong downtimeNano, +// final CountDownLatch startLatch, +// final CountDownLatch finishLatch) { +// +// return new Thread( +// () -> { +// long failureTimeNano = 0; +// try { +// final Properties props = ConnectionStringHelper.getDefaultProperties(); +// +// props.setProperty( +// MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, +// String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); +// props.setProperty( +// MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, +// String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); +// CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); +// +// FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); +// FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_INTERVAL_MS)); +// FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); +// PLUGINS.set(props, "efm"); +// +// final Connection conn = +// openConnectionWithRetry( +// ConnectionStringHelper.getWrapperUrl( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpoint(), +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpointPort(), +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getDefaultDbName()), +// props); +// LOGGER.finest("WrapperEfm connection is open."); +// +// Thread.sleep(1000); +// startLatch.countDown(); // notify that this thread is ready for work +// startLatch.await( +// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test +// +// LOGGER.finest("WrapperEfm Starting long query..."); +// // Execute long query +// final Statement statement = conn.createStatement(); +// try (final ResultSet result = statement.executeQuery(QUERY)) { +// fail("Sleep query finished, should not be possible with network downed."); +// } catch (SQLException throwable) { // Catching executing query +// LOGGER.finest("WrapperEfm thread exception: " + throwable); +// +// // Calculate and add detection time +// assertTrue(downtimeNano.get() > 0); +// failureTimeNano = System.nanoTime() - downtimeNano.get(); +// } +// +// } catch (InterruptedException interruptedException) { +// // Ignore, stop the thread +// } catch (Exception exception) { +// fail("WrapperEfm thread exception: " + exception); +// } finally { +// PerfStat data = new PerfStat(); +// data.paramFailoverDelayMillis = sleepDelayMillis; +// data.paramDriverName = +// String.format( +// "AWS Wrapper (%s, EFM)", TestEnvironment.getCurrent().getCurrentDriver()); +// data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); +// LOGGER.finest("WrapperEfm Collected data: " + data); +// perfDataList.add(data); +// LOGGER.finest( +// "WrapperEfm Failure detection time is " + data.failureDetectionTimeMillis + "ms"); +// +// finishLatch.countDown(); +// LOGGER.finest("WrapperEfm thread is completed."); +// } +// }); +// } +// +// private Thread getThread_WrapperEfmFailover( +// final int sleepDelayMillis, +// final AtomicLong downtimeNano, +// final CountDownLatch startLatch, +// final CountDownLatch finishLatch) { +// +// return new Thread( +// () -> { +// long failureTimeNano = 0; +// try { +// final Properties props = ConnectionStringHelper.getDefaultProperties(); +// +// props.setProperty( +// MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, +// String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); +// props.setProperty( +// MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, +// String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); +// CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); +// +// FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); +// FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); +// FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); +// FAILOVER_TIMEOUT_MS.set(props, Integer.toString(EFM_FAILOVER_TIMEOUT_MS)); +// PLUGINS.set(props, "failover,efm"); +// +// final Connection conn = +// openConnectionWithRetry( +// ConnectionStringHelper.getWrapperUrl( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpoint(), +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpointPort(), +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getDefaultDbName()), +// props); +// LOGGER.finest("WrapperEfmFailover connection is open."); +// +// Thread.sleep(1000); +// startLatch.countDown(); // notify that this thread is ready for work +// startLatch.await( +// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test +// +// LOGGER.finest("WrapperEfmFailover Starting long query..."); +// // Execute long query +// final Statement statement = conn.createStatement(); +// try (final ResultSet result = statement.executeQuery(QUERY)) { +// fail("Sleep query finished, should not be possible with network downed."); +// } catch (SQLException throwable) { +// LOGGER.finest("WrapperEfmFailover thread exception: " + throwable); +// if (throwable instanceof FailoverSuccessSQLException) { +// // Calculate and add detection time +// assertTrue(downtimeNano.get() > 0); +// failureTimeNano = System.nanoTime() - downtimeNano.get(); +// } +// } +// +// } catch (InterruptedException interruptedException) { +// // Ignore, stop the thread +// } catch (Exception exception) { +// fail("WrapperEfmFailover thread exception: " + exception); +// } finally { +// PerfStat data = new PerfStat(); +// data.paramFailoverDelayMillis = sleepDelayMillis; +// data.paramDriverName = +// String.format( +// "AWS Wrapper (%s, EFM, Failover)", +// TestEnvironment.getCurrent().getCurrentDriver()); +// data.reconnectTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); +// LOGGER.finest("WrapperEfmFailover Collected data: " + data); +// perfDataList.add(data); +// LOGGER.finest( +// "WrapperEfmFailover Reconnect time is " + data.reconnectTimeMillis + "ms"); +// +// finishLatch.countDown(); +// LOGGER.finest("WrapperEfmFailover thread is completed."); +// } +// }); +// } +// +// private Thread getThread_DNS( +// final int sleepDelayMillis, +// final AtomicLong downtimeNano, +// final CountDownLatch startLatch, +// final CountDownLatch finishLatch) { +// +// return new Thread( +// () -> { +// long failureTimeNano = 0; +// String currentClusterIpAddress; +// +// try { +// currentClusterIpAddress = +// InetAddress.getByName( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpoint()) +// .getHostAddress(); +// LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); +// +// Thread.sleep(1000); +// startLatch.countDown(); // notify that this thread is ready for work +// startLatch.await( +// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test +// +// String clusterIpAddress = +// InetAddress.getByName( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpoint()) +// .getHostAddress(); +// +// long startTimeNano = System.nanoTime(); +// while (clusterIpAddress.equals(currentClusterIpAddress) +// && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { +// Thread.sleep(1000); +// clusterIpAddress = +// InetAddress.getByName( +// TestEnvironment.getCurrent() +// .getInfo() +// .getDatabaseInfo() +// .getClusterEndpoint()) +// .getHostAddress(); +// LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); +// } +// +// // DNS data has changed +// if (!clusterIpAddress.equals(currentClusterIpAddress)) { +// assertTrue(downtimeNano.get() > 0); +// failureTimeNano = System.nanoTime() - downtimeNano.get(); +// } +// +// } catch (InterruptedException interruptedException) { +// // Ignore, stop the thread +// } catch (Exception exception) { +// fail("Failover thread exception: " + exception); +// } finally { +// PerfStat data = new PerfStat(); +// data.paramFailoverDelayMillis = sleepDelayMillis; +// data.paramDriverName = "DNS"; +// data.dnsUpdateTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); +// LOGGER.finest("DNS Collected data: " + data); +// perfDataList.add(data); +// LOGGER.finest("DNS Update time is " + data.dnsUpdateTimeMillis + "ms"); +// +// finishLatch.countDown(); +// LOGGER.finest("DNS thread is completed."); +// } +// }); +// } +// +// private Connection openConnectionWithRetry(String url, Properties props) { +// Connection conn = null; +// int connectCount = 0; +// while (conn == null && connectCount < 10) { +// try { +// conn = DriverManager.getConnection(url, props); +// +// } catch (SQLException sqlEx) { +// // ignore, try to connect again +// } +// connectCount++; +// } +// +// if (conn == null) { +// fail("Can't connect to " + url); +// } +// return conn; +// } +// +// private void failoverCluster() throws InterruptedException { +// String clusterId = TestEnvironment.getCurrent().getInfo().getRdsDbName(); +// String randomNode = auroraUtil.getRandomDBClusterReaderInstanceId(clusterId); +// auroraUtil.failoverClusterToTarget(clusterId, randomNode); +// } +// +// private void ensureClusterHealthy() throws InterruptedException { +// +// auroraUtil.waitUntilClusterHasRightState( +// TestEnvironment.getCurrent().getInfo().getRdsDbName()); +// +// // Always get the latest topology info with writer as first +// List latestTopology = new ArrayList<>(); +// +// // Need to ensure that cluster details through API matches topology fetched through SQL +// // Wait up to 5min +// long startTimeNano = System.nanoTime(); +// while ((latestTopology.size() +// != TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances() +// || !auroraUtil.isDBInstanceWriter(latestTopology.get(0))) +// && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { +// +// Thread.sleep(5000); +// +// try { +// latestTopology = auroraUtil.getAuroraInstanceIds(); +// } catch (SQLException ex) { +// latestTopology = new ArrayList<>(); +// } +// } +// assertTrue( +// auroraUtil.isDBInstanceWriter( +// TestEnvironment.getCurrent().getInfo().getRdsDbName(), latestTopology.get(0))); +// String currentWriter = latestTopology.get(0); +// +// // Adjust database info to reflect a current writer and to move corresponding instance to +// // position 0. +// TestEnvironment.getCurrent().getInfo().getDatabaseInfo().moveInstanceFirst(currentWriter); +// TestEnvironment.getCurrent().getInfo().getProxyDatabaseInfo().moveInstanceFirst(currentWriter); +// +// auroraUtil.makeSureInstancesUp(TimeUnit.MINUTES.toSeconds(5)); +// +// TestAuroraHostListProvider.clearCache(); +// TestPluginServiceImpl.clearHostAvailabilityCache(); +// HostMonitorThreadContainer.releaseInstance(); +// HostMonitorServiceImpl.closeAllMonitors(); +// } +// +// private static Stream generateParams() { +// +// ArrayList args = new ArrayList<>(); +// +// for (int i = 1; i <= REPEAT_TIMES; i++) { +// args.add(Arguments.of(10000, i)); +// } +// for (int i = 1; i <= REPEAT_TIMES; i++) { +// args.add(Arguments.of(20000, i)); +// } +// for (int i = 1; i <= REPEAT_TIMES; i++) { +// args.add(Arguments.of(30000, i)); +// } +// for (int i = 1; i <= REPEAT_TIMES; i++) { +// args.add(Arguments.of(40000, i)); +// } +// for (int i = 1; i <= REPEAT_TIMES; i++) { +// args.add(Arguments.of(50000, i)); +// } +// for (int i = 1; i <= REPEAT_TIMES; i++) { +// args.add(Arguments.of(60000, i)); +// } +// +// return Stream.of(args.toArray(new Arguments[0])); +// } +// +// private static class PerfStat { +// +// public String paramDriverName; +// public int paramFailoverDelayMillis; +// public long failureDetectionTimeMillis; +// public long reconnectTimeMillis; +// public long dnsUpdateTimeMillis; +// +// public void writeHeader(Row row) { +// Cell cell = row.createCell(0); +// cell.setCellValue("Driver Configuration"); +// cell = row.createCell(1); +// cell.setCellValue("Failover Delay Millis"); +// cell = row.createCell(2); +// cell.setCellValue("Failure Detection Time Millis"); +// cell = row.createCell(3); +// cell.setCellValue("Reconnect Time Millis"); +// cell = row.createCell(4); +// cell.setCellValue("DNS Update Time Millis"); +// } +// +// public void writeData(Row row) { +// Cell cell = row.createCell(0); +// cell.setCellValue(this.paramDriverName); +// cell = row.createCell(1); +// cell.setCellValue(this.paramFailoverDelayMillis); +// cell = row.createCell(2); +// cell.setCellValue(this.failureDetectionTimeMillis); +// cell = row.createCell(3); +// cell.setCellValue(this.reconnectTimeMillis); +// cell = row.createCell(4); +// cell.setCellValue(this.dnsUpdateTimeMillis); +// } +// +// @Override +// public String toString() { +// return String.format("%s [\nparamDriverName=%s,\nparamFailoverDelayMillis=%d,\n" +// + "failureDetectionTimeMillis=%d,\nreconnectTimeMillis=%d,\ndnsUpdateTimeMillis=%d ]", +// super.toString(), +// this.paramDriverName, +// this.paramFailoverDelayMillis, +// this.failureDetectionTimeMillis, +// this.reconnectTimeMillis, +// this.dnsUpdateTimeMillis); +// } +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java index 355276ad3..45ed19fd4 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java @@ -1,970 +1,970 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.locks.ReentrantLock; -import java.util.logging.Logger; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.mock.TestPluginOne; -import software.amazon.jdbc.mock.TestPluginThree; -import software.amazon.jdbc.mock.TestPluginThrowException; -import software.amazon.jdbc.mock.TestPluginTwo; -import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; -import software.amazon.jdbc.plugin.DefaultConnectionPlugin; -import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; -import software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin; -import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.wrapper.ConnectionWrapper; - -public class ConnectionPluginManagerTests { - - private static final Logger LOGGER = Logger.getLogger(ConnectionPluginManagerTests.class.getName()); - - @Mock JdbcCallable mockSqlFunction; - @Mock ConnectionProvider mockConnectionProvider; - @Mock ConnectionWrapper mockConnectionWrapper; - @Mock TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock FullServicesContainer mockServicesContainer; - @Mock PluginService mockPluginService; - @Mock PluginManagerService mockPluginManagerService; - @Mock TargetDriverDialect mockTargetDriverDialect; - - ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); - - private AutoCloseable closeable; - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @BeforeEach - void init() { - closeable = MockitoAnnotations.openMocks(this); - when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); - when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); - when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); - } - - @Test - public void testExecuteJdbcCallA() throws Exception { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls)); - - final Properties testProperties = new Properties(); - - final Object[] testArgs = new Object[] {10, "arg2", 3.33}; - - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Object result = - target.execute( - String.class, - Exception.class, - Connection.class, - JdbcMethod.BLOB_LENGTH, - () -> { - calls.add("targetCall"); - return "resulTestValue"; - }, - testArgs); - - assertEquals("resulTestValue", result); - - assertEquals(7, calls.size()); - assertEquals("TestPluginOne:before", calls.get(0)); - assertEquals("TestPluginTwo:before", calls.get(1)); - assertEquals("TestPluginThree:before", calls.get(2)); - assertEquals("targetCall", calls.get(3)); - assertEquals("TestPluginThree:after", calls.get(4)); - assertEquals("TestPluginTwo:after", calls.get(5)); - assertEquals("TestPluginOne:after", calls.get(6)); - } - - @Test - public void testExecuteJdbcCallB() throws Exception { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls)); - - final Properties testProperties = new Properties(); - - final Object[] testArgs = new Object[] {10, "arg2", 3.33}; - - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Object result = - target.execute( - String.class, - Exception.class, - Connection.class, - JdbcMethod.BLOB_POSITION, - () -> { - calls.add("targetCall"); - return "resulTestValue"; - }, - testArgs); - - assertEquals("resulTestValue", result); - - assertEquals(5, calls.size()); - assertEquals("TestPluginOne:before", calls.get(0)); - assertEquals("TestPluginTwo:before", calls.get(1)); - assertEquals("targetCall", calls.get(2)); - assertEquals("TestPluginTwo:after", calls.get(3)); - assertEquals("TestPluginOne:after", calls.get(4)); - } - - @Test - public void testExecuteJdbcCallC() throws Exception { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls)); - - final Properties testProperties = new Properties(); - - final Object[] testArgs = new Object[] {10, "arg2", 3.33}; - - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Object result = - target.execute( - String.class, - Exception.class, - Connection.class, - JdbcMethod.BLOB_GETBYTES, - () -> { - calls.add("targetCall"); - return "resulTestValue"; - }, - testArgs); - - assertEquals("resulTestValue", result); - - assertEquals(3, calls.size()); - assertEquals("TestPluginOne:before", calls.get(0)); - assertEquals("targetCall", calls.get(1)); - assertEquals("TestPluginOne:after", calls.get(2)); - } - - @Test - public void testConnect() throws Exception { - - final Connection expectedConnection = mock(Connection.class); - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls, expectedConnection)); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Connection conn = target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, - true, null); - - assertEquals(expectedConnection, conn); - assertEquals(4, calls.size()); - assertEquals("TestPluginOne:before connect", calls.get(0)); - assertEquals("TestPluginThree:before connect", calls.get(1)); - assertEquals("TestPluginThree:connection", calls.get(2)); - assertEquals("TestPluginOne:after connect", calls.get(3)); - } - - @Test - public void testConnectWithSkipPlugin() throws Exception { - - final Connection expectedConnection = mock(Connection.class); - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - final ConnectionPlugin pluginOne = new TestPluginOne(calls); - testPlugins.add(pluginOne); - final ConnectionPlugin pluginTwo = new TestPluginTwo(calls); - testPlugins.add(pluginTwo); - final ConnectionPlugin pluginThree = new TestPluginThree(calls, expectedConnection); - testPlugins.add(pluginThree); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Connection conn = target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, - true, pluginOne); - - assertEquals(expectedConnection, conn); - assertEquals(2, calls.size()); - assertEquals("TestPluginThree:before connect", calls.get(0)); - assertEquals("TestPluginThree:connection", calls.get(1)); - } - - @Test - public void testForceConnect() throws Exception { - - final Connection expectedConnection = mock(Connection.class); - final ArrayList calls = new ArrayList<>(); - final ArrayList testPlugins = new ArrayList<>(); - - // TestPluginOne is not an AuthenticationConnectionPlugin. - testPlugins.add(new TestPluginOne(calls)); - - // TestPluginTwo is an AuthenticationConnectionPlugin, but it's not subscribed to "forceConnect" method. - testPlugins.add(new TestPluginTwo(calls)); - - // TestPluginThree is an AuthenticationConnectionPlugin, and it's subscribed to "forceConnect" method. - testPlugins.add(new TestPluginThree(calls, expectedConnection)); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Connection conn = target.forceConnect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, - true, - null); - - // Expecting only TestPluginThree to participate in forceConnect(). - assertEquals(expectedConnection, conn); - assertEquals(4, calls.size()); - assertEquals("TestPluginOne:before forceConnect", calls.get(0)); - assertEquals("TestPluginThree:before forceConnect", calls.get(1)); - assertEquals("TestPluginThree:forced connection", calls.get(2)); - assertEquals("TestPluginOne:after forceConnect", calls.get(3)); - } - - @Test - public void testConnectWithSQLExceptionBefore() { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThrowException(calls, SQLException.class, true)); - testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - assertThrows( - SQLException.class, - () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); - - assertEquals(2, calls.size()); - assertEquals("TestPluginOne:before connect", calls.get(0)); - assertEquals("TestPluginThrowException:before", calls.get(1)); - } - - @Test - public void testConnectWithSQLExceptionAfter() { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThrowException(calls, SQLException.class, false)); - testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - assertThrows( - SQLException.class, - () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); - - assertEquals(5, calls.size()); - assertEquals("TestPluginOne:before connect", calls.get(0)); - assertEquals("TestPluginThrowException:before", calls.get(1)); - assertEquals("TestPluginThree:before connect", calls.get(2)); - assertEquals("TestPluginThree:connection", calls.get(3)); - assertEquals("TestPluginThrowException:after", calls.get(4)); - } - - @Test - public void testConnectWithUnexpectedExceptionBefore() { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, true)); - testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Exception ex = - assertThrows( - IllegalArgumentException.class, - () -> target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); - - assertEquals(2, calls.size()); - assertEquals("TestPluginOne:before connect", calls.get(0)); - assertEquals("TestPluginThrowException:before", calls.get(1)); - } - - @Test - public void testConnectWithUnexpectedExceptionAfter() { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, false)); - testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - - final Exception ex = - assertThrows( - IllegalArgumentException.class, - () -> target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); - - assertEquals(5, calls.size()); - assertEquals("TestPluginOne:before connect", calls.get(0)); - assertEquals("TestPluginThrowException:before", calls.get(1)); - assertEquals("TestPluginThree:before connect", calls.get(2)); - assertEquals("TestPluginThree:connection", calls.get(3)); - assertEquals("TestPluginThrowException:after", calls.get(4)); - } - - @Test - public void testExecuteCachedJdbcCallA() throws Exception { - - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls)); - - final Properties testProperties = new Properties(); - - final Object[] testArgs = new Object[] {10, "arg2", 3.33}; - - final ConnectionPluginManager target = Mockito.spy( - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); - - Object result = - target.execute( - String.class, - Exception.class, - Connection.class, - JdbcMethod.BLOB_LENGTH, - () -> { - calls.add("targetCall"); - return "resulTestValue"; - }, - testArgs); - - assertEquals("resulTestValue", result); - - // The method has been called just once to generate a final lambda and cache it. - verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); - - assertEquals(7, calls.size()); - assertEquals("TestPluginOne:before", calls.get(0)); - assertEquals("TestPluginTwo:before", calls.get(1)); - assertEquals("TestPluginThree:before", calls.get(2)); - assertEquals("targetCall", calls.get(3)); - assertEquals("TestPluginThree:after", calls.get(4)); - assertEquals("TestPluginTwo:after", calls.get(5)); - assertEquals("TestPluginOne:after", calls.get(6)); - - calls.clear(); - - result = - target.execute( - String.class, - Exception.class, - Connection.class, - JdbcMethod.BLOB_LENGTH, - () -> { - calls.add("targetCall"); - return "anotherResulTestValue"; - }, - testArgs); - - assertEquals("anotherResulTestValue", result); - - // No additional calls to this method occurred. It's still been called once. - verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); - - assertEquals(7, calls.size()); - assertEquals("TestPluginOne:before", calls.get(0)); - assertEquals("TestPluginTwo:before", calls.get(1)); - assertEquals("TestPluginThree:before", calls.get(2)); - assertEquals("targetCall", calls.get(3)); - assertEquals("TestPluginThree:after", calls.get(4)); - assertEquals("TestPluginTwo:after", calls.get(5)); - assertEquals("TestPluginOne:after", calls.get(6)); - } - - @Test - public void testForceConnectCachedJdbcCallForceConnect() throws Exception { - - final ArrayList calls = new ArrayList<>(); - final Connection mockConnection = mock(Connection.class); - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls, mockConnection)); - - final HostSpec testHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("test-instance").build(); - - final Properties testProperties = new Properties(); - - final ConnectionPluginManager target = Mockito.spy( - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); - - Object result = target.forceConnect( - "any", - testHostSpec, - testProperties, - true, - null); - - assertEquals(mockConnection, result); - - // The method has been called just once to generate a final lambda and cache it. - verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); - - assertEquals(4, calls.size()); - assertEquals("TestPluginOne:before forceConnect", calls.get(0)); - assertEquals("TestPluginThree:before forceConnect", calls.get(1)); - assertEquals("TestPluginThree:forced connection", calls.get(2)); - assertEquals("TestPluginOne:after forceConnect", calls.get(3)); - - calls.clear(); - - result = target.forceConnect( - "any", - testHostSpec, - testProperties, - true, - null); - - assertEquals(mockConnection, result); - - // No additional calls to this method occurred. It's still been called once. - verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); - - assertEquals(4, calls.size()); - assertEquals("TestPluginOne:before forceConnect", calls.get(0)); - assertEquals("TestPluginThree:before forceConnect", calls.get(1)); - assertEquals("TestPluginThree:forced connection", calls.get(2)); - assertEquals("TestPluginOne:after forceConnect", calls.get(3)); - } - - @Test - public void testExecuteAgainstOldConnection() throws Exception { - final ArrayList calls = new ArrayList<>(); - - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(calls)); - testPlugins.add(new TestPluginTwo(calls)); - testPlugins.add(new TestPluginThree(calls)); - - final Properties testProperties = new Properties(); - - final Connection mockOldConnection = mock(Connection.class); - final Connection mockCurrentConnection = mock(Connection.class); - final Statement mockOldStatement = mock(Statement.class); - final ResultSet mockOldResultSet = mock(ResultSet.class); - - when(mockPluginService.getCurrentConnection()).thenReturn(mockCurrentConnection); - when(mockOldStatement.getConnection()).thenReturn(mockOldConnection); - when(mockOldResultSet.getStatement()).thenReturn(mockOldStatement); - - final ConnectionPluginManager target = - new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - assertThrows(SQLException.class, - () -> target.execute(String.class, Exception.class, mockOldConnection, - JdbcMethod.CALLABLESTATEMENT_GETCONNECTION, () -> "result", null)); - assertThrows(SQLException.class, - () -> target.execute(String.class, Exception.class, mockOldStatement, - JdbcMethod.CALLABLESTATEMENT_GETMORERESULTS, () -> "result", null)); - assertThrows(SQLException.class, - () -> target.execute(String.class, Exception.class, mockOldResultSet, - JdbcMethod.RESULTSET_GETSTATEMENT, () -> "result", null)); - - assertDoesNotThrow( - () -> target.execute(Void.class, SQLException.class, mockOldConnection, - JdbcMethod.CONNECTION_CLOSE, mockSqlFunction, - null)); - assertDoesNotThrow( - () -> target.execute(Void.class, SQLException.class, mockOldConnection, - JdbcMethod.CONNECTION_ABORT, mockSqlFunction, - null)); - assertDoesNotThrow( - () -> target.execute(Void.class, SQLException.class, mockOldStatement, - JdbcMethod.STATEMENT_CLOSE, mockSqlFunction, - null)); - assertDoesNotThrow( - () -> target.execute(Void.class, SQLException.class, mockOldResultSet, - JdbcMethod.RESULTSET_CLOSE, mockSqlFunction, - null)); - } - - @Test - public void testDefaultPlugins() throws SQLException { - final Properties testProperties = new Properties(); - - final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( - mockConnectionProvider, - null, - mockConnectionWrapper, - mockTelemetryFactory)); - target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); - - assertEquals(4, target.plugins.size()); - assertEquals(AuroraConnectionTrackerPlugin.class, target.plugins.get(0).getClass()); - assertEquals(software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin.class, - target.plugins.get(1).getClass()); - assertEquals(HostMonitoringConnectionPlugin.class, target.plugins.get(2).getClass()); - assertEquals(DefaultConnectionPlugin.class, target.plugins.get(3).getClass()); - } - - @Test - public void testNoWrapperPlugins() throws SQLException { - final Properties testProperties = new Properties(); - testProperties.setProperty(PropertyDefinition.PLUGINS.name, ""); - - final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( - mockConnectionProvider, - null, - mockConnectionWrapper, - mockTelemetryFactory)); - target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); - - assertEquals(1, target.plugins.size()); - } - - @Test - public void testOverridingDefaultPluginsWithPluginCodes() throws SQLException { - final Properties testProperties = new Properties(); - testProperties.setProperty("wrapperPlugins", "logQuery"); - - final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( - mockConnectionProvider, - null, - mockConnectionWrapper, - mockTelemetryFactory)); - target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); - - assertEquals(2, target.plugins.size()); - assertEquals(LogQueryConnectionPlugin.class, target.plugins.get(0).getClass()); - assertEquals(DefaultConnectionPlugin.class, target.plugins.get(1).getClass()); - } - - @Test - public void testTwoConnectionsDoNotBlockOneAnother() throws Exception { - - final Properties testProperties = new Properties(); - final ArrayList testPlugins = new ArrayList<>(); - testPlugins.add(new TestPluginOne(new ArrayList<>())); - - final ConnectionProvider mockConnectionProvider1 = Mockito.mock(ConnectionProvider.class); - final ConnectionWrapper mockConnectionWrapper1 = Mockito.mock(ConnectionWrapper.class); - final PluginService mockPluginService1 = Mockito.mock(PluginService.class); - final TelemetryFactory mockTelemetryFactory1 = Mockito.mock(TelemetryFactory.class); - final Object object1 = new Object(); - when(mockPluginService1.getTelemetryFactory()).thenReturn(mockTelemetryFactory1); - when(mockTelemetryFactory1.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory1.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - - final ConnectionPluginManager pluginManager1 = - new ConnectionPluginManager(mockConnectionProvider1, - null, testProperties, testPlugins, mockConnectionWrapper1, - mockPluginService1, mockTelemetryFactory1); - - final ConnectionProvider mockConnectionProvider2 = Mockito.mock(ConnectionProvider.class); - final ConnectionWrapper mockConnectionWrapper2 = Mockito.mock(ConnectionWrapper.class); - final PluginService mockPluginService2 = Mockito.mock(PluginService.class); - final TelemetryFactory mockTelemetryFactory2 = Mockito.mock(TelemetryFactory.class); - final Object object2 = new Object(); - when(mockPluginService2.getTelemetryFactory()).thenReturn(mockTelemetryFactory2); - when(mockTelemetryFactory2.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory2.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - - final ConnectionPluginManager pluginManager2 = - new ConnectionPluginManager(mockConnectionProvider2, - null, testProperties, testPlugins, mockConnectionWrapper2, - mockPluginService2, mockTelemetryFactory2); - - // Imaginary database resource is considered "locked" when latch is 0 - final CountDownLatch waitForDbResourceLocked = new CountDownLatch(1); - final ReentrantLock dbResourceLock = new ReentrantLock(); - final CountDownLatch waitForReleaseDbResourceToProceed = new CountDownLatch(1); - final AtomicBoolean dbResourceReleased = new AtomicBoolean(false); - final AtomicBoolean acquireDbResourceLockSuccessful = new AtomicBoolean(false); - - CompletableFuture.allOf( - - // Thread 1 - CompletableFuture.runAsync(() -> { - - LOGGER.info("thread-1: started"); - - WrapperUtils.executeWithPlugins( - Integer.class, - pluginManager1, - object1, - JdbcMethod.BLOB_POSITION, // any JdbcMethod that locks connection - () -> { - dbResourceLock.lock(); - waitForDbResourceLocked.countDown(); - LOGGER.info("thread-1: locked"); - return 1; - }); - - LOGGER.info("thread-1: waiting for thread-2"); - try { - waitForReleaseDbResourceToProceed.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - LOGGER.info("thread-1: continue"); - - WrapperUtils.executeWithPlugins( - Integer.class, - pluginManager1, - object1, - JdbcMethod.BLOB_TRUNCATE, // any JdbcMethod that locks connection - () -> { - dbResourceLock.unlock(); - dbResourceReleased.set(true); - LOGGER.info("thread-1: unlocked"); - return 1; - }); - LOGGER.info("thread-1: completed"); - }), - - // Thread 2 - CompletableFuture.runAsync(() -> { - - LOGGER.info("thread-2: started"); - LOGGER.info("thread-2: waiting for thread-1"); - try { - waitForDbResourceLocked.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - LOGGER.info("thread-2: continue"); - - WrapperUtils.executeWithPlugins( - Integer.class, - pluginManager2, - object2, - JdbcMethod.BLOB_LENGTH, // any JdbcMethod that locks connection - () -> { - waitForReleaseDbResourceToProceed.countDown(); - LOGGER.info("thread-2: try to acquire a lock"); - try { - acquireDbResourceLockSuccessful.set(dbResourceLock.tryLock(5, TimeUnit.SECONDS)); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return 1; - }); - LOGGER.info("thread-2: completed"); - }) - ).join(); - - assertTrue(dbResourceReleased.get()); - assertTrue(acquireDbResourceLockSuccessful.get()); - } - - @Test - public void testGetHostSpecByStrategy_givenPluginWithNoSubscriptions_thenThrowsSqlException() throws SQLException { - final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); - when(mockPlugin.getSubscribedMethods()).thenReturn(Collections.emptySet()); - when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); - - final List testPlugins = Collections.singletonList(mockPlugin); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - final HostRole inputHostRole = HostRole.WRITER; - final String inputStrategy = "someStrategy"; - - assertThrows( - SQLException.class, - () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); - } - - @Test - public void testGetHostSpecByStrategy_givenPluginWithDiffSubscription_thenThrowsSqlException() throws SQLException { - final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); - when(mockPlugin.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); - when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); - - final List testPlugins = Collections.singletonList(mockPlugin); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - final HostRole inputHostRole = HostRole.WRITER; - final String inputStrategy = "someStrategy"; - - assertThrows( - SQLException.class, - () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); - } - - @Test - public void testGetHostSpecByStrategy_givenUnsupportedPlugin_thenThrowsSqlException() throws SQLException { - final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); - when(mockPlugin.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); - when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); - - final List testPlugins = Collections.singletonList(mockPlugin); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - final HostRole inputHostRole = HostRole.WRITER; - final String inputStrategy = "someStrategy"; - - assertThrows( - SQLException.class, - () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); - } - - @Test - public void testGetHostSpecByStrategy_givenSupportedSubscribedPlugin_thenThrowsSqlException() throws SQLException { - final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); - - when(mockPlugin.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); - - final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("expected-instance").build(); - when(mockPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); - - final List testPlugins = Collections.singletonList(mockPlugin); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - final HostRole inputHostRole = HostRole.WRITER; - final String inputStrategy = "someStrategy"; - final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); - - verify(mockPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); - assertEquals(expectedHostSpec, actualHostSpec); - } - - @Test - public void testGetHostSpecByStrategy_givenMultiplePlugins() throws SQLException { - final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); - final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); - final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); - final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); - final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); - - final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, - unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); - - when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); - when(unsubscribedPlugin1.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); - when(unsupportedSubscribedPlugin0.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); - when(unsupportedSubscribedPlugin1.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); - when(supportedSubscribedPlugin.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); - - when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); - when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); - when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any())) - .thenThrow(new UnsupportedOperationException()); - when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any())) - .thenThrow(new UnsupportedOperationException()); - - final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("expected-instance").build(); - when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - final HostRole inputHostRole = HostRole.WRITER; - final String inputStrategy = "someStrategy"; - final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); - - verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); - assertEquals(expectedHostSpec, actualHostSpec); - } - - @Test - public void testGetHostSpecByStrategy_givenInputHostsAndMultiplePlugins() throws SQLException { - final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); - final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); - final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); - final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); - final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); - - final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, - unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); - - when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); - when(unsubscribedPlugin1.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); - when(unsupportedSubscribedPlugin0.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); - when(unsupportedSubscribedPlugin1.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); - when(supportedSubscribedPlugin.getSubscribedMethods()) - .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); - - when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); - when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); - when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())) - .thenThrow(new UnsupportedOperationException()); - when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())) - .thenThrow(new UnsupportedOperationException()); - - final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("expected-instance").build(); - when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any(), any())).thenReturn(expectedHostSpec); - - final Properties testProperties = new Properties(); - final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, - null, testProperties, testPlugins, mockConnectionWrapper, - mockPluginService, mockTelemetryFactory); - - final List inputHosts = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("expected-instance").build()); - final HostRole inputHostRole = HostRole.WRITER; - final String inputStrategy = "someStrategy"; - final HostSpec actualHostSpec = - connectionPluginManager.getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); - - verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); - assertEquals(expectedHostSpec, actualHostSpec); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc; +// +// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.mock; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.HashSet; +// import java.util.List; +// import java.util.Properties; +// import java.util.concurrent.CompletableFuture; +// import java.util.concurrent.CountDownLatch; +// import java.util.concurrent.TimeUnit; +// import java.util.concurrent.atomic.AtomicBoolean; +// import java.util.concurrent.locks.ReentrantLock; +// import java.util.logging.Logger; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.mock.TestPluginOne; +// import software.amazon.jdbc.mock.TestPluginThree; +// import software.amazon.jdbc.mock.TestPluginThrowException; +// import software.amazon.jdbc.mock.TestPluginTwo; +// import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; +// import software.amazon.jdbc.plugin.DefaultConnectionPlugin; +// import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; +// import software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin; +// import software.amazon.jdbc.profile.ConfigurationProfile; +// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.WrapperUtils; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.wrapper.ConnectionWrapper; +// +// public class ConnectionPluginManagerTests { +// +// private static final Logger LOGGER = Logger.getLogger(ConnectionPluginManagerTests.class.getName()); +// +// @Mock JdbcCallable mockSqlFunction; +// @Mock ConnectionProvider mockConnectionProvider; +// @Mock ConnectionWrapper mockConnectionWrapper; +// @Mock TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock FullServicesContainer mockServicesContainer; +// @Mock PluginService mockPluginService; +// @Mock PluginManagerService mockPluginManagerService; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// +// ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); +// +// private AutoCloseable closeable; +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @BeforeEach +// void init() { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); +// when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); +// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); +// } +// +// @Test +// public void testExecuteJdbcCallA() throws Exception { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls)); +// +// final Properties testProperties = new Properties(); +// +// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; +// +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Object result = +// target.execute( +// String.class, +// Exception.class, +// Connection.class, +// JdbcMethod.BLOB_LENGTH, +// () -> { +// calls.add("targetCall"); +// return "resulTestValue"; +// }, +// testArgs); +// +// assertEquals("resulTestValue", result); +// +// assertEquals(7, calls.size()); +// assertEquals("TestPluginOne:before", calls.get(0)); +// assertEquals("TestPluginTwo:before", calls.get(1)); +// assertEquals("TestPluginThree:before", calls.get(2)); +// assertEquals("targetCall", calls.get(3)); +// assertEquals("TestPluginThree:after", calls.get(4)); +// assertEquals("TestPluginTwo:after", calls.get(5)); +// assertEquals("TestPluginOne:after", calls.get(6)); +// } +// +// @Test +// public void testExecuteJdbcCallB() throws Exception { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls)); +// +// final Properties testProperties = new Properties(); +// +// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; +// +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Object result = +// target.execute( +// String.class, +// Exception.class, +// Connection.class, +// JdbcMethod.BLOB_POSITION, +// () -> { +// calls.add("targetCall"); +// return "resulTestValue"; +// }, +// testArgs); +// +// assertEquals("resulTestValue", result); +// +// assertEquals(5, calls.size()); +// assertEquals("TestPluginOne:before", calls.get(0)); +// assertEquals("TestPluginTwo:before", calls.get(1)); +// assertEquals("targetCall", calls.get(2)); +// assertEquals("TestPluginTwo:after", calls.get(3)); +// assertEquals("TestPluginOne:after", calls.get(4)); +// } +// +// @Test +// public void testExecuteJdbcCallC() throws Exception { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls)); +// +// final Properties testProperties = new Properties(); +// +// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; +// +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Object result = +// target.execute( +// String.class, +// Exception.class, +// Connection.class, +// JdbcMethod.BLOB_GETBYTES, +// () -> { +// calls.add("targetCall"); +// return "resulTestValue"; +// }, +// testArgs); +// +// assertEquals("resulTestValue", result); +// +// assertEquals(3, calls.size()); +// assertEquals("TestPluginOne:before", calls.get(0)); +// assertEquals("targetCall", calls.get(1)); +// assertEquals("TestPluginOne:after", calls.get(2)); +// } +// +// @Test +// public void testConnect() throws Exception { +// +// final Connection expectedConnection = mock(Connection.class); +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls, expectedConnection)); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Connection conn = target.connect("any", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, +// true, null); +// +// assertEquals(expectedConnection, conn); +// assertEquals(4, calls.size()); +// assertEquals("TestPluginOne:before connect", calls.get(0)); +// assertEquals("TestPluginThree:before connect", calls.get(1)); +// assertEquals("TestPluginThree:connection", calls.get(2)); +// assertEquals("TestPluginOne:after connect", calls.get(3)); +// } +// +// @Test +// public void testConnectWithSkipPlugin() throws Exception { +// +// final Connection expectedConnection = mock(Connection.class); +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// final ConnectionPlugin pluginOne = new TestPluginOne(calls); +// testPlugins.add(pluginOne); +// final ConnectionPlugin pluginTwo = new TestPluginTwo(calls); +// testPlugins.add(pluginTwo); +// final ConnectionPlugin pluginThree = new TestPluginThree(calls, expectedConnection); +// testPlugins.add(pluginThree); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Connection conn = target.connect("any", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, +// true, pluginOne); +// +// assertEquals(expectedConnection, conn); +// assertEquals(2, calls.size()); +// assertEquals("TestPluginThree:before connect", calls.get(0)); +// assertEquals("TestPluginThree:connection", calls.get(1)); +// } +// +// @Test +// public void testForceConnect() throws Exception { +// +// final Connection expectedConnection = mock(Connection.class); +// final ArrayList calls = new ArrayList<>(); +// final ArrayList testPlugins = new ArrayList<>(); +// +// // TestPluginOne is not an AuthenticationConnectionPlugin. +// testPlugins.add(new TestPluginOne(calls)); +// +// // TestPluginTwo is an AuthenticationConnectionPlugin, but it's not subscribed to "forceConnect" method. +// testPlugins.add(new TestPluginTwo(calls)); +// +// // TestPluginThree is an AuthenticationConnectionPlugin, and it's subscribed to "forceConnect" method. +// testPlugins.add(new TestPluginThree(calls, expectedConnection)); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Connection conn = target.forceConnect("any", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, +// true, +// null); +// +// // Expecting only TestPluginThree to participate in forceConnect(). +// assertEquals(expectedConnection, conn); +// assertEquals(4, calls.size()); +// assertEquals("TestPluginOne:before forceConnect", calls.get(0)); +// assertEquals("TestPluginThree:before forceConnect", calls.get(1)); +// assertEquals("TestPluginThree:forced connection", calls.get(2)); +// assertEquals("TestPluginOne:after forceConnect", calls.get(3)); +// } +// +// @Test +// public void testConnectWithSQLExceptionBefore() { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThrowException(calls, SQLException.class, true)); +// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// assertThrows( +// SQLException.class, +// () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), +// testProperties, true, null)); +// +// assertEquals(2, calls.size()); +// assertEquals("TestPluginOne:before connect", calls.get(0)); +// assertEquals("TestPluginThrowException:before", calls.get(1)); +// } +// +// @Test +// public void testConnectWithSQLExceptionAfter() { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThrowException(calls, SQLException.class, false)); +// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// assertThrows( +// SQLException.class, +// () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), +// testProperties, true, null)); +// +// assertEquals(5, calls.size()); +// assertEquals("TestPluginOne:before connect", calls.get(0)); +// assertEquals("TestPluginThrowException:before", calls.get(1)); +// assertEquals("TestPluginThree:before connect", calls.get(2)); +// assertEquals("TestPluginThree:connection", calls.get(3)); +// assertEquals("TestPluginThrowException:after", calls.get(4)); +// } +// +// @Test +// public void testConnectWithUnexpectedExceptionBefore() { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, true)); +// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Exception ex = +// assertThrows( +// IllegalArgumentException.class, +// () -> target.connect("any", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), +// testProperties, true, null)); +// +// assertEquals(2, calls.size()); +// assertEquals("TestPluginOne:before connect", calls.get(0)); +// assertEquals("TestPluginThrowException:before", calls.get(1)); +// } +// +// @Test +// public void testConnectWithUnexpectedExceptionAfter() { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, false)); +// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); +// +// final Exception ex = +// assertThrows( +// IllegalArgumentException.class, +// () -> target.connect("any", +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), +// testProperties, true, null)); +// +// assertEquals(5, calls.size()); +// assertEquals("TestPluginOne:before connect", calls.get(0)); +// assertEquals("TestPluginThrowException:before", calls.get(1)); +// assertEquals("TestPluginThree:before connect", calls.get(2)); +// assertEquals("TestPluginThree:connection", calls.get(3)); +// assertEquals("TestPluginThrowException:after", calls.get(4)); +// } +// +// @Test +// public void testExecuteCachedJdbcCallA() throws Exception { +// +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls)); +// +// final Properties testProperties = new Properties(); +// +// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; +// +// final ConnectionPluginManager target = Mockito.spy( +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); +// +// Object result = +// target.execute( +// String.class, +// Exception.class, +// Connection.class, +// JdbcMethod.BLOB_LENGTH, +// () -> { +// calls.add("targetCall"); +// return "resulTestValue"; +// }, +// testArgs); +// +// assertEquals("resulTestValue", result); +// +// // The method has been called just once to generate a final lambda and cache it. +// verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); +// +// assertEquals(7, calls.size()); +// assertEquals("TestPluginOne:before", calls.get(0)); +// assertEquals("TestPluginTwo:before", calls.get(1)); +// assertEquals("TestPluginThree:before", calls.get(2)); +// assertEquals("targetCall", calls.get(3)); +// assertEquals("TestPluginThree:after", calls.get(4)); +// assertEquals("TestPluginTwo:after", calls.get(5)); +// assertEquals("TestPluginOne:after", calls.get(6)); +// +// calls.clear(); +// +// result = +// target.execute( +// String.class, +// Exception.class, +// Connection.class, +// JdbcMethod.BLOB_LENGTH, +// () -> { +// calls.add("targetCall"); +// return "anotherResulTestValue"; +// }, +// testArgs); +// +// assertEquals("anotherResulTestValue", result); +// +// // No additional calls to this method occurred. It's still been called once. +// verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); +// +// assertEquals(7, calls.size()); +// assertEquals("TestPluginOne:before", calls.get(0)); +// assertEquals("TestPluginTwo:before", calls.get(1)); +// assertEquals("TestPluginThree:before", calls.get(2)); +// assertEquals("targetCall", calls.get(3)); +// assertEquals("TestPluginThree:after", calls.get(4)); +// assertEquals("TestPluginTwo:after", calls.get(5)); +// assertEquals("TestPluginOne:after", calls.get(6)); +// } +// +// @Test +// public void testForceConnectCachedJdbcCallForceConnect() throws Exception { +// +// final ArrayList calls = new ArrayList<>(); +// final Connection mockConnection = mock(Connection.class); +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls, mockConnection)); +// +// final HostSpec testHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("test-instance").build(); +// +// final Properties testProperties = new Properties(); +// +// final ConnectionPluginManager target = Mockito.spy( +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); +// +// Object result = target.forceConnect( +// "any", +// testHostSpec, +// testProperties, +// true, +// null); +// +// assertEquals(mockConnection, result); +// +// // The method has been called just once to generate a final lambda and cache it. +// verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); +// +// assertEquals(4, calls.size()); +// assertEquals("TestPluginOne:before forceConnect", calls.get(0)); +// assertEquals("TestPluginThree:before forceConnect", calls.get(1)); +// assertEquals("TestPluginThree:forced connection", calls.get(2)); +// assertEquals("TestPluginOne:after forceConnect", calls.get(3)); +// +// calls.clear(); +// +// result = target.forceConnect( +// "any", +// testHostSpec, +// testProperties, +// true, +// null); +// +// assertEquals(mockConnection, result); +// +// // No additional calls to this method occurred. It's still been called once. +// verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); +// +// assertEquals(4, calls.size()); +// assertEquals("TestPluginOne:before forceConnect", calls.get(0)); +// assertEquals("TestPluginThree:before forceConnect", calls.get(1)); +// assertEquals("TestPluginThree:forced connection", calls.get(2)); +// assertEquals("TestPluginOne:after forceConnect", calls.get(3)); +// } +// +// @Test +// public void testExecuteAgainstOldConnection() throws Exception { +// final ArrayList calls = new ArrayList<>(); +// +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(calls)); +// testPlugins.add(new TestPluginTwo(calls)); +// testPlugins.add(new TestPluginThree(calls)); +// +// final Properties testProperties = new Properties(); +// +// final Connection mockOldConnection = mock(Connection.class); +// final Connection mockCurrentConnection = mock(Connection.class); +// final Statement mockOldStatement = mock(Statement.class); +// final ResultSet mockOldResultSet = mock(ResultSet.class); +// +// when(mockPluginService.getCurrentConnection()).thenReturn(mockCurrentConnection); +// when(mockOldStatement.getConnection()).thenReturn(mockOldConnection); +// when(mockOldResultSet.getStatement()).thenReturn(mockOldStatement); +// +// final ConnectionPluginManager target = +// new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// assertThrows(SQLException.class, +// () -> target.execute(String.class, Exception.class, mockOldConnection, +// JdbcMethod.CALLABLESTATEMENT_GETCONNECTION, () -> "result", null)); +// assertThrows(SQLException.class, +// () -> target.execute(String.class, Exception.class, mockOldStatement, +// JdbcMethod.CALLABLESTATEMENT_GETMORERESULTS, () -> "result", null)); +// assertThrows(SQLException.class, +// () -> target.execute(String.class, Exception.class, mockOldResultSet, +// JdbcMethod.RESULTSET_GETSTATEMENT, () -> "result", null)); +// +// assertDoesNotThrow( +// () -> target.execute(Void.class, SQLException.class, mockOldConnection, +// JdbcMethod.CONNECTION_CLOSE, mockSqlFunction, +// null)); +// assertDoesNotThrow( +// () -> target.execute(Void.class, SQLException.class, mockOldConnection, +// JdbcMethod.CONNECTION_ABORT, mockSqlFunction, +// null)); +// assertDoesNotThrow( +// () -> target.execute(Void.class, SQLException.class, mockOldStatement, +// JdbcMethod.STATEMENT_CLOSE, mockSqlFunction, +// null)); +// assertDoesNotThrow( +// () -> target.execute(Void.class, SQLException.class, mockOldResultSet, +// JdbcMethod.RESULTSET_CLOSE, mockSqlFunction, +// null)); +// } +// +// @Test +// public void testDefaultPlugins() throws SQLException { +// final Properties testProperties = new Properties(); +// +// final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( +// mockConnectionProvider, +// null, +// mockConnectionWrapper, +// mockTelemetryFactory)); +// target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); +// +// assertEquals(4, target.plugins.size()); +// assertEquals(AuroraConnectionTrackerPlugin.class, target.plugins.get(0).getClass()); +// assertEquals(software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin.class, +// target.plugins.get(1).getClass()); +// assertEquals(HostMonitoringConnectionPlugin.class, target.plugins.get(2).getClass()); +// assertEquals(DefaultConnectionPlugin.class, target.plugins.get(3).getClass()); +// } +// +// @Test +// public void testNoWrapperPlugins() throws SQLException { +// final Properties testProperties = new Properties(); +// testProperties.setProperty(PropertyDefinition.PLUGINS.name, ""); +// +// final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( +// mockConnectionProvider, +// null, +// mockConnectionWrapper, +// mockTelemetryFactory)); +// target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); +// +// assertEquals(1, target.plugins.size()); +// } +// +// @Test +// public void testOverridingDefaultPluginsWithPluginCodes() throws SQLException { +// final Properties testProperties = new Properties(); +// testProperties.setProperty("wrapperPlugins", "logQuery"); +// +// final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( +// mockConnectionProvider, +// null, +// mockConnectionWrapper, +// mockTelemetryFactory)); +// target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); +// +// assertEquals(2, target.plugins.size()); +// assertEquals(LogQueryConnectionPlugin.class, target.plugins.get(0).getClass()); +// assertEquals(DefaultConnectionPlugin.class, target.plugins.get(1).getClass()); +// } +// +// @Test +// public void testTwoConnectionsDoNotBlockOneAnother() throws Exception { +// +// final Properties testProperties = new Properties(); +// final ArrayList testPlugins = new ArrayList<>(); +// testPlugins.add(new TestPluginOne(new ArrayList<>())); +// +// final ConnectionProvider mockConnectionProvider1 = Mockito.mock(ConnectionProvider.class); +// final ConnectionWrapper mockConnectionWrapper1 = Mockito.mock(ConnectionWrapper.class); +// final PluginService mockPluginService1 = Mockito.mock(PluginService.class); +// final TelemetryFactory mockTelemetryFactory1 = Mockito.mock(TelemetryFactory.class); +// final Object object1 = new Object(); +// when(mockPluginService1.getTelemetryFactory()).thenReturn(mockTelemetryFactory1); +// when(mockTelemetryFactory1.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory1.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// +// final ConnectionPluginManager pluginManager1 = +// new ConnectionPluginManager(mockConnectionProvider1, +// null, testProperties, testPlugins, mockConnectionWrapper1, +// mockPluginService1, mockTelemetryFactory1); +// +// final ConnectionProvider mockConnectionProvider2 = Mockito.mock(ConnectionProvider.class); +// final ConnectionWrapper mockConnectionWrapper2 = Mockito.mock(ConnectionWrapper.class); +// final PluginService mockPluginService2 = Mockito.mock(PluginService.class); +// final TelemetryFactory mockTelemetryFactory2 = Mockito.mock(TelemetryFactory.class); +// final Object object2 = new Object(); +// when(mockPluginService2.getTelemetryFactory()).thenReturn(mockTelemetryFactory2); +// when(mockTelemetryFactory2.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory2.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// +// final ConnectionPluginManager pluginManager2 = +// new ConnectionPluginManager(mockConnectionProvider2, +// null, testProperties, testPlugins, mockConnectionWrapper2, +// mockPluginService2, mockTelemetryFactory2); +// +// // Imaginary database resource is considered "locked" when latch is 0 +// final CountDownLatch waitForDbResourceLocked = new CountDownLatch(1); +// final ReentrantLock dbResourceLock = new ReentrantLock(); +// final CountDownLatch waitForReleaseDbResourceToProceed = new CountDownLatch(1); +// final AtomicBoolean dbResourceReleased = new AtomicBoolean(false); +// final AtomicBoolean acquireDbResourceLockSuccessful = new AtomicBoolean(false); +// +// CompletableFuture.allOf( +// +// // Thread 1 +// CompletableFuture.runAsync(() -> { +// +// LOGGER.info("thread-1: started"); +// +// WrapperUtils.executeWithPlugins( +// Integer.class, +// pluginManager1, +// object1, +// JdbcMethod.BLOB_POSITION, // any JdbcMethod that locks connection +// () -> { +// dbResourceLock.lock(); +// waitForDbResourceLocked.countDown(); +// LOGGER.info("thread-1: locked"); +// return 1; +// }); +// +// LOGGER.info("thread-1: waiting for thread-2"); +// try { +// waitForReleaseDbResourceToProceed.await(); +// } catch (InterruptedException e) { +// throw new RuntimeException(e); +// } +// LOGGER.info("thread-1: continue"); +// +// WrapperUtils.executeWithPlugins( +// Integer.class, +// pluginManager1, +// object1, +// JdbcMethod.BLOB_TRUNCATE, // any JdbcMethod that locks connection +// () -> { +// dbResourceLock.unlock(); +// dbResourceReleased.set(true); +// LOGGER.info("thread-1: unlocked"); +// return 1; +// }); +// LOGGER.info("thread-1: completed"); +// }), +// +// // Thread 2 +// CompletableFuture.runAsync(() -> { +// +// LOGGER.info("thread-2: started"); +// LOGGER.info("thread-2: waiting for thread-1"); +// try { +// waitForDbResourceLocked.await(); +// } catch (InterruptedException e) { +// throw new RuntimeException(e); +// } +// LOGGER.info("thread-2: continue"); +// +// WrapperUtils.executeWithPlugins( +// Integer.class, +// pluginManager2, +// object2, +// JdbcMethod.BLOB_LENGTH, // any JdbcMethod that locks connection +// () -> { +// waitForReleaseDbResourceToProceed.countDown(); +// LOGGER.info("thread-2: try to acquire a lock"); +// try { +// acquireDbResourceLockSuccessful.set(dbResourceLock.tryLock(5, TimeUnit.SECONDS)); +// } catch (InterruptedException e) { +// throw new RuntimeException(e); +// } +// return 1; +// }); +// LOGGER.info("thread-2: completed"); +// }) +// ).join(); +// +// assertTrue(dbResourceReleased.get()); +// assertTrue(acquireDbResourceLockSuccessful.get()); +// } +// +// @Test +// public void testGetHostSpecByStrategy_givenPluginWithNoSubscriptions_thenThrowsSqlException() throws SQLException { +// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); +// when(mockPlugin.getSubscribedMethods()).thenReturn(Collections.emptySet()); +// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); +// +// final List testPlugins = Collections.singletonList(mockPlugin); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// final HostRole inputHostRole = HostRole.WRITER; +// final String inputStrategy = "someStrategy"; +// +// assertThrows( +// SQLException.class, +// () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); +// } +// +// @Test +// public void testGetHostSpecByStrategy_givenPluginWithDiffSubscription_thenThrowsSqlException() throws SQLException { +// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); +// when(mockPlugin.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); +// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); +// +// final List testPlugins = Collections.singletonList(mockPlugin); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// final HostRole inputHostRole = HostRole.WRITER; +// final String inputStrategy = "someStrategy"; +// +// assertThrows( +// SQLException.class, +// () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); +// } +// +// @Test +// public void testGetHostSpecByStrategy_givenUnsupportedPlugin_thenThrowsSqlException() throws SQLException { +// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); +// when(mockPlugin.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); +// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); +// +// final List testPlugins = Collections.singletonList(mockPlugin); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// final HostRole inputHostRole = HostRole.WRITER; +// final String inputStrategy = "someStrategy"; +// +// assertThrows( +// SQLException.class, +// () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); +// } +// +// @Test +// public void testGetHostSpecByStrategy_givenSupportedSubscribedPlugin_thenThrowsSqlException() throws SQLException { +// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); +// +// when(mockPlugin.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); +// +// final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("expected-instance").build(); +// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); +// +// final List testPlugins = Collections.singletonList(mockPlugin); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// final HostRole inputHostRole = HostRole.WRITER; +// final String inputStrategy = "someStrategy"; +// final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); +// +// verify(mockPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); +// assertEquals(expectedHostSpec, actualHostSpec); +// } +// +// @Test +// public void testGetHostSpecByStrategy_givenMultiplePlugins() throws SQLException { +// final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); +// final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); +// final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); +// final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); +// final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); +// +// final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, +// unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); +// +// when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); +// when(unsubscribedPlugin1.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); +// when(unsupportedSubscribedPlugin0.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); +// when(unsupportedSubscribedPlugin1.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); +// when(supportedSubscribedPlugin.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); +// +// when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); +// when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); +// when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any())) +// .thenThrow(new UnsupportedOperationException()); +// when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any())) +// .thenThrow(new UnsupportedOperationException()); +// +// final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("expected-instance").build(); +// when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// final HostRole inputHostRole = HostRole.WRITER; +// final String inputStrategy = "someStrategy"; +// final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); +// +// verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); +// assertEquals(expectedHostSpec, actualHostSpec); +// } +// +// @Test +// public void testGetHostSpecByStrategy_givenInputHostsAndMultiplePlugins() throws SQLException { +// final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); +// final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); +// final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); +// final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); +// final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); +// +// final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, +// unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); +// +// when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); +// when(unsubscribedPlugin1.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); +// when(unsupportedSubscribedPlugin0.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); +// when(unsupportedSubscribedPlugin1.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); +// when(supportedSubscribedPlugin.getSubscribedMethods()) +// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); +// +// when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); +// when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); +// when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())) +// .thenThrow(new UnsupportedOperationException()); +// when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())) +// .thenThrow(new UnsupportedOperationException()); +// +// final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("expected-instance").build(); +// when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any(), any())).thenReturn(expectedHostSpec); +// +// final Properties testProperties = new Properties(); +// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, +// null, testProperties, testPlugins, mockConnectionWrapper, +// mockPluginService, mockTelemetryFactory); +// +// final List inputHosts = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("expected-instance").build()); +// final HostRole inputHostRole = HostRole.WRITER; +// final String inputStrategy = "someStrategy"; +// final HostSpec actualHostSpec = +// connectionPluginManager.getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); +// +// verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); +// assertEquals(expectedHostSpec, actualHostSpec); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java index 3b47f12bb..735cebd67 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java @@ -1,289 +1,289 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Properties; -import java.util.stream.Stream; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.dialect.AuroraMysqlDialect; -import software.amazon.jdbc.dialect.AuroraPgDialect; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.DialectManager; -import software.amazon.jdbc.dialect.MariaDbDialect; -import software.amazon.jdbc.dialect.MysqlDialect; -import software.amazon.jdbc.dialect.PgDialect; -import software.amazon.jdbc.dialect.RdsMultiAzDbClusterMysqlDialect; -import software.amazon.jdbc.dialect.RdsMultiAzDbClusterPgDialect; -import software.amazon.jdbc.dialect.RdsMysqlDialect; -import software.amazon.jdbc.dialect.RdsPgDialect; -import software.amazon.jdbc.exceptions.ExceptionManager; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.storage.StorageService; - -public class DialectDetectionTests { - private static final String LOCALHOST = "localhost"; - private static final String RDS_DATABASE = "database-1.xyz.us-east-2.rds.amazonaws.com"; - private static final String RDS_AURORA_DATABASE = "database-2.cluster-xyz.us-east-2.rds.amazonaws.com"; - private static final String MYSQL_PROTOCOL = "jdbc:mysql://"; - private static final String PG_PROTOCOL = "jdbc:postgresql://"; - private static final String MARIA_PROTOCOL = "jdbc:mariadb://"; - private final Properties props = new Properties(); - private AutoCloseable closeable; - @Mock private FullServicesContainer mockServicesContainer; - @Mock private HostListProviderService mockHostListProviderService; - @Mock private StorageService mockStorageService; - @Mock private Connection mockConnection; - @Mock private Statement mockStatement; - @Mock private ResultSet mockSuccessResultSet; - @Mock private ResultSet mockFailResultSet; - @Mock private HostSpec mockHost; - @Mock private ConnectionPluginManager mockPluginManager; - @Mock private TargetDriverDialect mockTargetDriverDialect; - @Mock private ResultSetMetaData mockResultSetMetaData; - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - when(this.mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); - when(this.mockServicesContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); - when(this.mockServicesContainer.getStorageService()).thenReturn(mockStorageService); - when(this.mockConnection.createStatement()).thenReturn(this.mockStatement); - when(this.mockHost.getUrl()).thenReturn("url"); - when(this.mockFailResultSet.next()).thenReturn(false); - mockPluginManager.plugins = new ArrayList<>(); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - DialectManager.resetEndpointCache(); - } - - PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { - PluginServiceImpl pluginService = spy( - new PluginServiceImpl( - mockServicesContainer, - new ExceptionManager(), - props, - protocol + host, - protocol, - null, - mockTargetDriverDialect, - null, - null)); - - when(this.mockServicesContainer.getHostListProviderService()).thenReturn(pluginService); - return pluginService; - } - - @ParameterizedTest - @MethodSource("getInitialDialectArguments") - public void testInitialDialectDetection(String protocol, String host, Object expectedDialect) throws SQLException { - final DialectManager dialectManager = new DialectManager(this.getPluginService(host, protocol)); - final Dialect dialect = dialectManager.getDialect(protocol, host, new Properties()); - assertEquals(expectedDialect, dialect.getClass()); - } - - static Stream getInitialDialectArguments() { - return Stream.of( - Arguments.of(MYSQL_PROTOCOL, LOCALHOST, MysqlDialect.class), - Arguments.of(MYSQL_PROTOCOL, RDS_DATABASE, RdsMysqlDialect.class), - Arguments.of(MYSQL_PROTOCOL, RDS_AURORA_DATABASE, AuroraMysqlDialect.class), - Arguments.of(PG_PROTOCOL, LOCALHOST, PgDialect.class), - Arguments.of(PG_PROTOCOL, RDS_DATABASE, RdsPgDialect.class), - Arguments.of(PG_PROTOCOL, RDS_AURORA_DATABASE, AuroraPgDialect.class), - Arguments.of(MARIA_PROTOCOL, LOCALHOST, MariaDbDialect.class), - Arguments.of(MARIA_PROTOCOL, RDS_DATABASE, MariaDbDialect.class), - Arguments.of(MARIA_PROTOCOL, RDS_AURORA_DATABASE, MariaDbDialect.class) - ); - } - - @Test - void testUpdateDialectMysqlUnchanged() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(MysqlDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectMysqlToRds() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); - when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); - when(mockSuccessResultSet.getString(2)).thenReturn( - "Source distribution", "Source distribution", ""); - when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); - when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - when(mockFailResultSet.next()).thenReturn(false); - final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); - } - - @Test - @Disabled - // TODO: fix me: need to split this test into two separate tests: - // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect - // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() - void testUpdateDialectMysqlToTaz() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); - when(mockSuccessResultSet.next()).thenReturn(true); - final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectMysqlToAurora() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); - when(mockSuccessResultSet.next()).thenReturn(true, false); - final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); - when(mockServicesContainer.getPluginService()).thenReturn(target); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectPgUnchanged() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(PgDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectPgToRds() throws SQLException { - when(mockStatement.executeQuery(any())) - .thenReturn(mockSuccessResultSet, mockFailResultSet, mockFailResultSet, mockSuccessResultSet); - when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); - when(mockSuccessResultSet.getBoolean("rds_tools")).thenReturn(true); - when(mockSuccessResultSet.getBoolean("aurora_stat_utils")).thenReturn(false); - when(mockSuccessResultSet.next()).thenReturn(true); - when(mockFailResultSet.next()).thenReturn(false); - final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(RdsPgDialect.class, target.dialect.getClass()); - } - - @Test - @Disabled - // TODO: fix me: need to split this test into two separate tests: - // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect - // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() - void testUpdateDialectPgToTaz() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); - when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); - when(mockSuccessResultSet.next()).thenReturn(true); - final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(RdsMultiAzDbClusterPgDialect.class, target.dialect.getClass()); - } - - @Test - @Disabled - // TODO: fix me: need to split this test into two separate tests: - // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect - // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() - void testUpdateDialectPgToAurora() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); - when(mockSuccessResultSet.next()).thenReturn(true); - when(mockSuccessResultSet.getBoolean(any())).thenReturn(true); - final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(AuroraPgDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectMariaUnchanged() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(MariaDbDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectMariaToMysqlRds() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); - when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); - when(mockSuccessResultSet.getString(2)).thenReturn( - "Source distribution", "Source distribution", ""); - when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); - when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - when(mockFailResultSet.next()).thenReturn(false); - final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); - } - - @Test - @Disabled - // TODO: fix me: need to split this test into two separate tests: - // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect - // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() - void testUpdateDialectMariaToMysqlTaz() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); - final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(RdsMultiAzDbClusterMysqlDialect.class, target.dialect.getClass()); - } - - @Test - void testUpdateDialectMariaToMysqlAurora() throws SQLException { - when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); - when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); - when(mockSuccessResultSet.next()).thenReturn(true, false); - final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); - when(mockServicesContainer.getPluginService()).thenReturn(target); - target.setInitialConnectionHostSpec(mockHost); - target.updateDialect(mockConnection); - assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.ResultSetMetaData; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.ArrayList; +// import java.util.Properties; +// import java.util.stream.Stream; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Disabled; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.Arguments; +// import org.junit.jupiter.params.provider.MethodSource; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.dialect.AuroraMysqlDialect; +// import software.amazon.jdbc.dialect.AuroraPgDialect; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.dialect.DialectManager; +// import software.amazon.jdbc.dialect.MariaDbDialect; +// import software.amazon.jdbc.dialect.MysqlDialect; +// import software.amazon.jdbc.dialect.PgDialect; +// import software.amazon.jdbc.dialect.RdsMultiAzDbClusterMysqlDialect; +// import software.amazon.jdbc.dialect.RdsMultiAzDbClusterPgDialect; +// import software.amazon.jdbc.dialect.RdsMysqlDialect; +// import software.amazon.jdbc.dialect.RdsPgDialect; +// import software.amazon.jdbc.exceptions.ExceptionManager; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.storage.StorageService; +// +// public class DialectDetectionTests { +// private static final String LOCALHOST = "localhost"; +// private static final String RDS_DATABASE = "database-1.xyz.us-east-2.rds.amazonaws.com"; +// private static final String RDS_AURORA_DATABASE = "database-2.cluster-xyz.us-east-2.rds.amazonaws.com"; +// private static final String MYSQL_PROTOCOL = "jdbc:mysql://"; +// private static final String PG_PROTOCOL = "jdbc:postgresql://"; +// private static final String MARIA_PROTOCOL = "jdbc:mariadb://"; +// private final Properties props = new Properties(); +// private AutoCloseable closeable; +// @Mock private FullServicesContainer mockServicesContainer; +// @Mock private HostListProviderService mockHostListProviderService; +// @Mock private StorageService mockStorageService; +// @Mock private Connection mockConnection; +// @Mock private Statement mockStatement; +// @Mock private ResultSet mockSuccessResultSet; +// @Mock private ResultSet mockFailResultSet; +// @Mock private HostSpec mockHost; +// @Mock private ConnectionPluginManager mockPluginManager; +// @Mock private TargetDriverDialect mockTargetDriverDialect; +// @Mock private ResultSetMetaData mockResultSetMetaData; +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// when(this.mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); +// when(this.mockServicesContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); +// when(this.mockServicesContainer.getStorageService()).thenReturn(mockStorageService); +// when(this.mockConnection.createStatement()).thenReturn(this.mockStatement); +// when(this.mockHost.getUrl()).thenReturn("url"); +// when(this.mockFailResultSet.next()).thenReturn(false); +// mockPluginManager.plugins = new ArrayList<>(); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// DialectManager.resetEndpointCache(); +// } +// +// PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { +// PluginServiceImpl pluginService = spy( +// new PluginServiceImpl( +// mockServicesContainer, +// new ExceptionManager(), +// props, +// protocol + host, +// protocol, +// null, +// mockTargetDriverDialect, +// null, +// null)); +// +// when(this.mockServicesContainer.getHostListProviderService()).thenReturn(pluginService); +// return pluginService; +// } +// +// @ParameterizedTest +// @MethodSource("getInitialDialectArguments") +// public void testInitialDialectDetection(String protocol, String host, Object expectedDialect) throws SQLException { +// final DialectManager dialectManager = new DialectManager(this.getPluginService(host, protocol)); +// final Dialect dialect = dialectManager.getDialect(protocol, host, new Properties()); +// assertEquals(expectedDialect, dialect.getClass()); +// } +// +// static Stream getInitialDialectArguments() { +// return Stream.of( +// Arguments.of(MYSQL_PROTOCOL, LOCALHOST, MysqlDialect.class), +// Arguments.of(MYSQL_PROTOCOL, RDS_DATABASE, RdsMysqlDialect.class), +// Arguments.of(MYSQL_PROTOCOL, RDS_AURORA_DATABASE, AuroraMysqlDialect.class), +// Arguments.of(PG_PROTOCOL, LOCALHOST, PgDialect.class), +// Arguments.of(PG_PROTOCOL, RDS_DATABASE, RdsPgDialect.class), +// Arguments.of(PG_PROTOCOL, RDS_AURORA_DATABASE, AuroraPgDialect.class), +// Arguments.of(MARIA_PROTOCOL, LOCALHOST, MariaDbDialect.class), +// Arguments.of(MARIA_PROTOCOL, RDS_DATABASE, MariaDbDialect.class), +// Arguments.of(MARIA_PROTOCOL, RDS_AURORA_DATABASE, MariaDbDialect.class) +// ); +// } +// +// @Test +// void testUpdateDialectMysqlUnchanged() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(MysqlDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectMysqlToRds() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); +// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); +// when(mockSuccessResultSet.getString(2)).thenReturn( +// "Source distribution", "Source distribution", ""); +// when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); +// when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); +// when(mockFailResultSet.next()).thenReturn(false); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); +// } +// +// @Test +// @Disabled +// // TODO: fix me: need to split this test into two separate tests: +// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect +// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() +// void testUpdateDialectMysqlToTaz() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); +// when(mockSuccessResultSet.next()).thenReturn(true); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectMysqlToAurora() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); +// when(mockSuccessResultSet.next()).thenReturn(true, false); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); +// when(mockServicesContainer.getPluginService()).thenReturn(target); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectPgUnchanged() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(PgDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectPgToRds() throws SQLException { +// when(mockStatement.executeQuery(any())) +// .thenReturn(mockSuccessResultSet, mockFailResultSet, mockFailResultSet, mockSuccessResultSet); +// when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); +// when(mockSuccessResultSet.getBoolean("rds_tools")).thenReturn(true); +// when(mockSuccessResultSet.getBoolean("aurora_stat_utils")).thenReturn(false); +// when(mockSuccessResultSet.next()).thenReturn(true); +// when(mockFailResultSet.next()).thenReturn(false); +// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(RdsPgDialect.class, target.dialect.getClass()); +// } +// +// @Test +// @Disabled +// // TODO: fix me: need to split this test into two separate tests: +// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect +// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() +// void testUpdateDialectPgToTaz() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); +// when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); +// when(mockSuccessResultSet.next()).thenReturn(true); +// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(RdsMultiAzDbClusterPgDialect.class, target.dialect.getClass()); +// } +// +// @Test +// @Disabled +// // TODO: fix me: need to split this test into two separate tests: +// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect +// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() +// void testUpdateDialectPgToAurora() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); +// when(mockSuccessResultSet.next()).thenReturn(true); +// when(mockSuccessResultSet.getBoolean(any())).thenReturn(true); +// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(AuroraPgDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectMariaUnchanged() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(MariaDbDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectMariaToMysqlRds() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); +// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); +// when(mockSuccessResultSet.getString(2)).thenReturn( +// "Source distribution", "Source distribution", ""); +// when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); +// when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); +// when(mockFailResultSet.next()).thenReturn(false); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); +// } +// +// @Test +// @Disabled +// // TODO: fix me: need to split this test into two separate tests: +// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect +// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() +// void testUpdateDialectMariaToMysqlTaz() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(RdsMultiAzDbClusterMysqlDialect.class, target.dialect.getClass()); +// } +// +// @Test +// void testUpdateDialectMariaToMysqlAurora() throws SQLException { +// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); +// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); +// when(mockSuccessResultSet.next()).thenReturn(true, false); +// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); +// when(mockServicesContainer.getPluginService()).thenReturn(target); +// target.setInitialConnectionHostSpec(mockHost); +// target.updateDialect(mockConnection); +// assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index 07a22941f..4b234a741 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -1,946 +1,946 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.stream.Stream; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.dialect.AuroraPgDialect; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.DialectManager; -import software.amazon.jdbc.dialect.MysqlDialect; -import software.amazon.jdbc.exceptions.ExceptionManager; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.states.SessionStateService; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.events.EventPublisher; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.storage.TestStorageServiceImpl; - -public class PluginServiceImplTests { - - private static final Properties PROPERTIES = new Properties(); - private static final String URL = "url"; - private static final String DRIVER_PROTOCOL = "driverProtocol"; - private StorageService storageService; - private AutoCloseable closeable; - - @Mock FullServicesContainer servicesContainer; - @Mock EventPublisher mockEventPublisher; - @Mock ConnectionPluginManager pluginManager; - @Mock Connection newConnection; - @Mock Connection oldConnection; - @Mock HostListProvider hostListProvider; - @Mock DialectManager dialectManager; - @Mock TargetDriverDialect mockTargetDriverDialect; - @Mock Statement statement; - @Mock ResultSet resultSet; - ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); - @Mock SessionStateService sessionStateService; - - @Captor ArgumentCaptor> argumentChanges; - @Captor ArgumentCaptor>> argumentChangesMap; - @Captor ArgumentCaptor argumentSkipPlugin; - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - when(oldConnection.isClosed()).thenReturn(false); - when(newConnection.createStatement()).thenReturn(statement); - when(statement.executeQuery(any())).thenReturn(resultSet); - when(servicesContainer.getConnectionPluginManager()).thenReturn(pluginManager); - when(servicesContainer.getStorageService()).thenReturn(storageService); - storageService = new TestStorageServiceImpl(mockEventPublisher); - PluginServiceImpl.hostAvailabilityExpiringCache.clear(); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - storageService.clearAll(); - PluginServiceImpl.hostAvailabilityExpiringCache.clear(); - } - - @Test - public void testOldConnectionNoSuggestion() throws SQLException { - when(pluginManager.notifyConnectionChanged(any(), any())) - .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") - .build(); - - target.setCurrentConnection(newConnection, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); - - assertNotEquals(oldConnection, target.currentConnection); - assertEquals(newConnection, target.currentConnection); - assertEquals("new-host", target.currentHostSpec.getHost()); - verify(oldConnection, times(1)).close(); - } - - @Test - public void testOldConnectionDisposeSuggestion() throws SQLException { - when(pluginManager.notifyConnectionChanged(any(), any())) - .thenReturn(EnumSet.of(OldConnectionSuggestedAction.DISPOSE)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") - .build(); - - target.setCurrentConnection(newConnection, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); - - assertNotEquals(oldConnection, target.currentConnection); - assertEquals(newConnection, target.currentConnection); - assertEquals("new-host", target.currentHostSpec.getHost()); - verify(oldConnection, times(1)).close(); - } - - @Test - public void testOldConnectionPreserveSuggestion() throws SQLException { - when(pluginManager.notifyConnectionChanged(any(), any())) - .thenReturn(EnumSet.of(OldConnectionSuggestedAction.PRESERVE)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") - .build(); - - target.setCurrentConnection(newConnection, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); - - assertNotEquals(oldConnection, target.currentConnection); - assertEquals(newConnection, target.currentConnection); - assertEquals("new-host", target.currentHostSpec.getHost()); - verify(oldConnection, times(0)).close(); - } - - @Test - public void testOldConnectionMixedSuggestion() throws SQLException { - when(pluginManager.notifyConnectionChanged(any(), any())) - .thenReturn( - EnumSet.of( - OldConnectionSuggestedAction.NO_OPINION, - OldConnectionSuggestedAction.PRESERVE, - OldConnectionSuggestedAction.DISPOSE)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") - .build(); - - target.setCurrentConnection(newConnection, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); - - assertNotEquals(oldConnection, target.currentConnection); - assertEquals(newConnection, target.currentConnection); - assertEquals("new-host", target.currentHostSpec.getHost()); - verify(oldConnection, times(0)).close(); - } - - @Test - public void testChangesNewConnectionNewHostNewPortNewRoleNewAvailability() throws SQLException { - when(pluginManager.notifyConnectionChanged( - argumentChanges.capture(), argumentSkipPlugin.capture())) - .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).build(); - - target.setCurrentConnection( - newConnection, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-host").port(2000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build()); - - assertNull(argumentSkipPlugin.getValue()); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); - } - - @Test - public void testChangesNewConnectionNewRoleNewAvailability() throws SQLException { - when(pluginManager.notifyConnectionChanged( - argumentChanges.capture(), argumentSkipPlugin.capture())) - .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build(); - - target.setCurrentConnection(newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE) - .build()); - - assertNull(argumentSkipPlugin.getValue()); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); - } - - @Test - public void testChangesNewConnection() throws SQLException { - when(pluginManager.notifyConnectionChanged( - argumentChanges.capture(), argumentSkipPlugin.capture())) - .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) - .build(); - - target.setCurrentConnection( - newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) - .build()); - - assertNull(argumentSkipPlugin.getValue()); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); - assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); - assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); - } - - @Test - public void testChangesNoChanges() throws SQLException { - when(pluginManager.notifyConnectionChanged(any(), any())).thenReturn( - EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); - - PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.currentConnection = oldConnection; - target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(); - - target.setCurrentConnection( - oldConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) - .build()); - - verify(pluginManager, times(0)).notifyConnectionChanged(any(), any()); - } - - @Test - public void testSetNodeListAdded() throws SQLException { - - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - when(hostListProvider.refresh()).thenReturn(Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build())); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = new ArrayList<>(); - target.hostListProvider = hostListProvider; - - target.refreshHostList(); - - assertEquals(1, target.getAllHosts().size()); - assertEquals("hostA", target.getAllHosts().get(0).getHost()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(1, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_ADDED)); - } - - @Test - public void testSetNodeListDeleted() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - when(hostListProvider.refresh()).thenReturn(Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build())); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); - target.hostListProvider = hostListProvider; - - target.refreshHostList(); - - assertEquals(1, target.getAllHosts().size()); - assertEquals("hostB", target.getAllHosts().get(0).getHost()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(1, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_DELETED)); - } - - @Test - public void testSetNodeListChanged() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - when(hostListProvider.refresh()).thenReturn( - Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") - .port(HostSpec.NO_PORT).role(HostRole.READER).build())); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); - target.hostListProvider = hostListProvider; - - target.refreshHostList(); - - assertEquals(1, target.getAllHosts().size()); - assertEquals("hostA", target.getAllHosts().get(0).getHost()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(2, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); - assertTrue(hostAChanges.contains(NodeChangeOptions.PROMOTED_TO_READER)); - } - - @Test - public void testSetNodeListNoChanges() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(any()); - - when(hostListProvider.refresh()).thenReturn( - Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build())); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); - target.hostListProvider = hostListProvider; - - target.refreshHostList(); - - assertEquals(1, target.getAllHosts().size()); - assertEquals("hostA", target.getAllHosts().get(0).getHost()); - verify(pluginManager, times(0)).notifyNodeListChanged(any()); - } - - @Test - public void testNodeAvailabilityNotChanged() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) - .build()); - - Set aliases = new HashSet<>(); - aliases.add("hostA"); - target.setAvailability(aliases, HostAvailability.AVAILABLE); - - assertEquals(1, target.getAllHosts().size()); - assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); - verify(pluginManager, never()).notifyNodeListChanged(any()); - } - - @Test - public void testNodeAvailabilityChanged_WentDown() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) - .build()); - - Set aliases = new HashSet<>(); - aliases.add("hostA"); - target.setAvailability(aliases, HostAvailability.NOT_AVAILABLE); - - assertEquals(1, target.getAllHosts().size()); - assertEquals(HostAvailability.NOT_AVAILABLE, target.getAllHosts().get(0).getAvailability()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(2, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); - assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_DOWN)); - } - - @Test - public void testNodeAvailabilityChanged_WentUp() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.allHosts = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build()); - - Set aliases = new HashSet<>(); - aliases.add("hostA"); - target.setAvailability(aliases, HostAvailability.AVAILABLE); - - assertEquals(1, target.getAllHosts().size()); - assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(2, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); - assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); - } - - @Test - public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build(); - hostA.addAlias("ip-10-10-10-10"); - hostA.addAlias("hostA.custom.domain.com"); - final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build(); - hostB.addAlias("ip-10-10-10-10"); - hostB.addAlias("hostB.custom.domain.com"); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - - target.allHosts = Arrays.asList(hostA, hostB); - - Set aliases = new HashSet<>(); - aliases.add("hostA.custom.domain.com"); - target.setAvailability(aliases, HostAvailability.AVAILABLE); - - assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); - assertEquals(HostAvailability.NOT_AVAILABLE, hostB.getAvailability()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(2, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); - assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); - } - - @Test - public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQLException { - doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - - final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build();; - hostA.addAlias("ip-10-10-10-10"); - hostA.addAlias("hostA.custom.domain.com"); - final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build(); - hostB.addAlias("ip-10-10-10-10"); - hostB.addAlias("hostB.custom.domain.com"); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - - target.allHosts = Arrays.asList(hostA, hostB); - - Set aliases = new HashSet<>(); - aliases.add("ip-10-10-10-10"); - target.setAvailability(aliases, HostAvailability.AVAILABLE); - - assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); - assertEquals(HostAvailability.AVAILABLE, hostB.getAvailability()); - verify(pluginManager, times(1)).notifyNodeListChanged(any()); - - Map> notifiedChanges = argumentChangesMap.getValue(); - assertTrue(notifiedChanges.containsKey("hostA/")); - EnumSet hostAChanges = notifiedChanges.get("hostA/"); - assertEquals(2, hostAChanges.size()); - assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); - assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); - - assertTrue(notifiedChanges.containsKey("hostB/")); - EnumSet hostBChanges = notifiedChanges.get("hostB/"); - assertEquals(2, hostBChanges.size()); - assertTrue(hostBChanges.contains(NodeChangeOptions.NODE_CHANGED)); - assertTrue(hostBChanges.contains(NodeChangeOptions.WENT_UP)); - } - - @Test - void testRefreshHostList_withCachedHostAvailability() throws SQLException { - final List newHostSpecs = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - final List newHostSpecs2 = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - final List expectedHostSpecs = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - final List expectedHostSpecs2 = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - - PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, - PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); - PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, - PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); - when(hostListProvider.refresh()).thenReturn(newHostSpecs); - when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - when(target.getHostListProvider()).thenReturn(hostListProvider); - - assertNotEquals(expectedHostSpecs, newHostSpecs); - target.refreshHostList(); - assertEquals(expectedHostSpecs, newHostSpecs); - - PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, - PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); - target.refreshHostList(newConnection); - assertEquals(expectedHostSpecs2, newHostSpecs); - } - - @Test - void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { - final List newHostSpecs = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - final List expectedHostSpecs = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - final List expectedHostSpecs2 = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) - .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() - ); - - PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, - PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); - PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, - PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); - when(hostListProvider.forceRefresh()).thenReturn(newHostSpecs); - when(hostListProvider.forceRefresh(newConnection)).thenReturn(newHostSpecs); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - when(target.getHostListProvider()).thenReturn(hostListProvider); - - assertNotEquals(expectedHostSpecs, newHostSpecs); - target.forceRefreshHostList(); - assertEquals(expectedHostSpecs, newHostSpecs); - - PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, - PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); - target.forceRefreshHostList(newConnection); - assertEquals(expectedHostSpecs2, newHostSpecs); - } - - @Test - void testIdentifyConnectionWithNoAliases() throws SQLException { - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - when(target.getHostListProvider()).thenReturn(hostListProvider); - - when(target.getDialect()).thenReturn(new MysqlDialect()); - assertNull(target.identifyConnection(newConnection)); - } - - @Test - void testIdentifyConnectionWithAliases() throws SQLException { - final HostSpec expected = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("test") - .build(); - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.hostListProvider = hostListProvider; - when(target.getHostListProvider()).thenReturn(hostListProvider); - when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); - - when(target.getDialect()).thenReturn(new AuroraPgDialect()); - final HostSpec actual = target.identifyConnection(newConnection); - verify(target, never()).getCurrentHostSpec(); - verify(hostListProvider).identifyConnection(newConnection); - assertEquals(expected, actual); - } - - @Test - void testFillAliasesNonEmptyAliases() throws SQLException { - final HostSpec oneAlias = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo") - .build(); - oneAlias.addAlias(oneAlias.asAlias()); - - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - - assertEquals(1, oneAlias.getAliases().size()); - target.fillAliases(newConnection, oneAlias); - // Fill aliases should return directly and no additional aliases should be added. - assertEquals(1, oneAlias.getAliases().size()); - } - - @ParameterizedTest - @MethodSource("fillAliasesDialects") - void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { - final HostSpec empty = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo").build(); - PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); - target.hostListProvider = hostListProvider; - when(target.getDialect()).thenReturn(dialect); - when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. - when(resultSet.getString(eq(1))).thenReturn("ip"); - if (dialect instanceof AuroraPgDialect) { - when(hostListProvider.identifyConnection(eq(newConnection))) - .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance").build()); - } - - target.fillAliases(newConnection, empty); - - final String[] aliases = empty.getAliases().toArray(new String[] {}); - assertArrayEquals(expectedInstanceAliases, aliases); - } - - private static Stream fillAliasesDialects() { - return Stream.of( - Arguments.of(new AuroraPgDialect(), new String[]{"instance", "foo", "ip"}), - Arguments.of(new MysqlDialect(), new String[]{"foo", "ip"}) - ); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc; +// +// import static org.junit.jupiter.api.Assertions.assertArrayEquals; +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doNothing; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.HashSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import java.util.Set; +// import java.util.stream.Stream; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.Arguments; +// import org.junit.jupiter.params.provider.MethodSource; +// import org.mockito.ArgumentCaptor; +// import org.mockito.Captor; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.dialect.AuroraPgDialect; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.dialect.DialectManager; +// import software.amazon.jdbc.dialect.MysqlDialect; +// import software.amazon.jdbc.exceptions.ExceptionManager; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.profile.ConfigurationProfile; +// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +// import software.amazon.jdbc.states.SessionStateService; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.events.EventPublisher; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.storage.TestStorageServiceImpl; +// +// public class PluginServiceImplTests { +// +// private static final Properties PROPERTIES = new Properties(); +// private static final String URL = "url"; +// private static final String DRIVER_PROTOCOL = "driverProtocol"; +// private StorageService storageService; +// private AutoCloseable closeable; +// +// @Mock FullServicesContainer servicesContainer; +// @Mock EventPublisher mockEventPublisher; +// @Mock ConnectionPluginManager pluginManager; +// @Mock Connection newConnection; +// @Mock Connection oldConnection; +// @Mock HostListProvider hostListProvider; +// @Mock DialectManager dialectManager; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// @Mock Statement statement; +// @Mock ResultSet resultSet; +// ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); +// @Mock SessionStateService sessionStateService; +// +// @Captor ArgumentCaptor> argumentChanges; +// @Captor ArgumentCaptor>> argumentChangesMap; +// @Captor ArgumentCaptor argumentSkipPlugin; +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// when(oldConnection.isClosed()).thenReturn(false); +// when(newConnection.createStatement()).thenReturn(statement); +// when(statement.executeQuery(any())).thenReturn(resultSet); +// when(servicesContainer.getConnectionPluginManager()).thenReturn(pluginManager); +// when(servicesContainer.getStorageService()).thenReturn(storageService); +// storageService = new TestStorageServiceImpl(mockEventPublisher); +// PluginServiceImpl.hostAvailabilityExpiringCache.clear(); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// storageService.clearAll(); +// PluginServiceImpl.hostAvailabilityExpiringCache.clear(); +// } +// +// @Test +// public void testOldConnectionNoSuggestion() throws SQLException { +// when(pluginManager.notifyConnectionChanged(any(), any())) +// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") +// .build(); +// +// target.setCurrentConnection(newConnection, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); +// +// assertNotEquals(oldConnection, target.currentConnection); +// assertEquals(newConnection, target.currentConnection); +// assertEquals("new-host", target.currentHostSpec.getHost()); +// verify(oldConnection, times(1)).close(); +// } +// +// @Test +// public void testOldConnectionDisposeSuggestion() throws SQLException { +// when(pluginManager.notifyConnectionChanged(any(), any())) +// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.DISPOSE)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") +// .build(); +// +// target.setCurrentConnection(newConnection, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); +// +// assertNotEquals(oldConnection, target.currentConnection); +// assertEquals(newConnection, target.currentConnection); +// assertEquals("new-host", target.currentHostSpec.getHost()); +// verify(oldConnection, times(1)).close(); +// } +// +// @Test +// public void testOldConnectionPreserveSuggestion() throws SQLException { +// when(pluginManager.notifyConnectionChanged(any(), any())) +// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.PRESERVE)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") +// .build(); +// +// target.setCurrentConnection(newConnection, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); +// +// assertNotEquals(oldConnection, target.currentConnection); +// assertEquals(newConnection, target.currentConnection); +// assertEquals("new-host", target.currentHostSpec.getHost()); +// verify(oldConnection, times(0)).close(); +// } +// +// @Test +// public void testOldConnectionMixedSuggestion() throws SQLException { +// when(pluginManager.notifyConnectionChanged(any(), any())) +// .thenReturn( +// EnumSet.of( +// OldConnectionSuggestedAction.NO_OPINION, +// OldConnectionSuggestedAction.PRESERVE, +// OldConnectionSuggestedAction.DISPOSE)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") +// .build(); +// +// target.setCurrentConnection(newConnection, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); +// +// assertNotEquals(oldConnection, target.currentConnection); +// assertEquals(newConnection, target.currentConnection); +// assertEquals("new-host", target.currentHostSpec.getHost()); +// verify(oldConnection, times(0)).close(); +// } +// +// @Test +// public void testChangesNewConnectionNewHostNewPortNewRoleNewAvailability() throws SQLException { +// when(pluginManager.notifyConnectionChanged( +// argumentChanges.capture(), argumentSkipPlugin.capture())) +// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).build(); +// +// target.setCurrentConnection( +// newConnection, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("new-host").port(2000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build()); +// +// assertNull(argumentSkipPlugin.getValue()); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); +// } +// +// @Test +// public void testChangesNewConnectionNewRoleNewAvailability() throws SQLException { +// when(pluginManager.notifyConnectionChanged( +// argumentChanges.capture(), argumentSkipPlugin.capture())) +// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build(); +// +// target.setCurrentConnection(newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE) +// .build()); +// +// assertNull(argumentSkipPlugin.getValue()); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); +// } +// +// @Test +// public void testChangesNewConnection() throws SQLException { +// when(pluginManager.notifyConnectionChanged( +// argumentChanges.capture(), argumentSkipPlugin.capture())) +// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) +// .build(); +// +// target.setCurrentConnection( +// newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) +// .build()); +// +// assertNull(argumentSkipPlugin.getValue()); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); +// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); +// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); +// } +// +// @Test +// public void testChangesNoChanges() throws SQLException { +// when(pluginManager.notifyConnectionChanged(any(), any())).thenReturn( +// EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); +// +// PluginServiceImpl target = +// spy(new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.currentConnection = oldConnection; +// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(); +// +// target.setCurrentConnection( +// oldConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) +// .build()); +// +// verify(pluginManager, times(0)).notifyConnectionChanged(any(), any()); +// } +// +// @Test +// public void testSetNodeListAdded() throws SQLException { +// +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// when(hostListProvider.refresh()).thenReturn(Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build())); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = new ArrayList<>(); +// target.hostListProvider = hostListProvider; +// +// target.refreshHostList(); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals("hostA", target.getAllHosts().get(0).getHost()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(1, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_ADDED)); +// } +// +// @Test +// public void testSetNodeListDeleted() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// when(hostListProvider.refresh()).thenReturn(Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build())); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); +// target.hostListProvider = hostListProvider; +// +// target.refreshHostList(); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals("hostB", target.getAllHosts().get(0).getHost()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(1, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_DELETED)); +// } +// +// @Test +// public void testSetNodeListChanged() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// when(hostListProvider.refresh()).thenReturn( +// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") +// .port(HostSpec.NO_PORT).role(HostRole.READER).build())); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); +// target.hostListProvider = hostListProvider; +// +// target.refreshHostList(); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals("hostA", target.getAllHosts().get(0).getHost()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(2, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); +// assertTrue(hostAChanges.contains(NodeChangeOptions.PROMOTED_TO_READER)); +// } +// +// @Test +// public void testSetNodeListNoChanges() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(any()); +// +// when(hostListProvider.refresh()).thenReturn( +// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build())); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); +// target.hostListProvider = hostListProvider; +// +// target.refreshHostList(); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals("hostA", target.getAllHosts().get(0).getHost()); +// verify(pluginManager, times(0)).notifyNodeListChanged(any()); +// } +// +// @Test +// public void testNodeAvailabilityNotChanged() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) +// .build()); +// +// Set aliases = new HashSet<>(); +// aliases.add("hostA"); +// target.setAvailability(aliases, HostAvailability.AVAILABLE); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); +// verify(pluginManager, never()).notifyNodeListChanged(any()); +// } +// +// @Test +// public void testNodeAvailabilityChanged_WentDown() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) +// .build()); +// +// Set aliases = new HashSet<>(); +// aliases.add("hostA"); +// target.setAvailability(aliases, HostAvailability.NOT_AVAILABLE); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals(HostAvailability.NOT_AVAILABLE, target.getAllHosts().get(0).getAvailability()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(2, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); +// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_DOWN)); +// } +// +// @Test +// public void testNodeAvailabilityChanged_WentUp() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.allHosts = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build()); +// +// Set aliases = new HashSet<>(); +// aliases.add("hostA"); +// target.setAvailability(aliases, HostAvailability.AVAILABLE); +// +// assertEquals(1, target.getAllHosts().size()); +// assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(2, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); +// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); +// } +// +// @Test +// public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build(); +// hostA.addAlias("ip-10-10-10-10"); +// hostA.addAlias("hostA.custom.domain.com"); +// final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build(); +// hostB.addAlias("ip-10-10-10-10"); +// hostB.addAlias("hostB.custom.domain.com"); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// +// target.allHosts = Arrays.asList(hostA, hostB); +// +// Set aliases = new HashSet<>(); +// aliases.add("hostA.custom.domain.com"); +// target.setAvailability(aliases, HostAvailability.AVAILABLE); +// +// assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); +// assertEquals(HostAvailability.NOT_AVAILABLE, hostB.getAvailability()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(2, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); +// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); +// } +// +// @Test +// public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQLException { +// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); +// +// final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build();; +// hostA.addAlias("ip-10-10-10-10"); +// hostA.addAlias("hostA.custom.domain.com"); +// final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) +// .build(); +// hostB.addAlias("ip-10-10-10-10"); +// hostB.addAlias("hostB.custom.domain.com"); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// +// target.allHosts = Arrays.asList(hostA, hostB); +// +// Set aliases = new HashSet<>(); +// aliases.add("ip-10-10-10-10"); +// target.setAvailability(aliases, HostAvailability.AVAILABLE); +// +// assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); +// assertEquals(HostAvailability.AVAILABLE, hostB.getAvailability()); +// verify(pluginManager, times(1)).notifyNodeListChanged(any()); +// +// Map> notifiedChanges = argumentChangesMap.getValue(); +// assertTrue(notifiedChanges.containsKey("hostA/")); +// EnumSet hostAChanges = notifiedChanges.get("hostA/"); +// assertEquals(2, hostAChanges.size()); +// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); +// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); +// +// assertTrue(notifiedChanges.containsKey("hostB/")); +// EnumSet hostBChanges = notifiedChanges.get("hostB/"); +// assertEquals(2, hostBChanges.size()); +// assertTrue(hostBChanges.contains(NodeChangeOptions.NODE_CHANGED)); +// assertTrue(hostBChanges.contains(NodeChangeOptions.WENT_UP)); +// } +// +// @Test +// void testRefreshHostList_withCachedHostAvailability() throws SQLException { +// final List newHostSpecs = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// final List newHostSpecs2 = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// final List expectedHostSpecs = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// final List expectedHostSpecs2 = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// +// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, +// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); +// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, +// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); +// when(hostListProvider.refresh()).thenReturn(newHostSpecs); +// when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// when(target.getHostListProvider()).thenReturn(hostListProvider); +// +// assertNotEquals(expectedHostSpecs, newHostSpecs); +// target.refreshHostList(); +// assertEquals(expectedHostSpecs, newHostSpecs); +// +// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, +// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); +// target.refreshHostList(newConnection); +// assertEquals(expectedHostSpecs2, newHostSpecs); +// } +// +// @Test +// void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { +// final List newHostSpecs = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// final List expectedHostSpecs = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// final List expectedHostSpecs2 = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) +// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() +// ); +// +// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, +// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); +// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, +// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); +// when(hostListProvider.forceRefresh()).thenReturn(newHostSpecs); +// when(hostListProvider.forceRefresh(newConnection)).thenReturn(newHostSpecs); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// when(target.getHostListProvider()).thenReturn(hostListProvider); +// +// assertNotEquals(expectedHostSpecs, newHostSpecs); +// target.forceRefreshHostList(); +// assertEquals(expectedHostSpecs, newHostSpecs); +// +// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, +// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); +// target.forceRefreshHostList(newConnection); +// assertEquals(expectedHostSpecs2, newHostSpecs); +// } +// +// @Test +// void testIdentifyConnectionWithNoAliases() throws SQLException { +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// when(target.getHostListProvider()).thenReturn(hostListProvider); +// +// when(target.getDialect()).thenReturn(new MysqlDialect()); +// assertNull(target.identifyConnection(newConnection)); +// } +// +// @Test +// void testIdentifyConnectionWithAliases() throws SQLException { +// final HostSpec expected = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("test") +// .build(); +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.hostListProvider = hostListProvider; +// when(target.getHostListProvider()).thenReturn(hostListProvider); +// when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); +// +// when(target.getDialect()).thenReturn(new AuroraPgDialect()); +// final HostSpec actual = target.identifyConnection(newConnection); +// verify(target, never()).getCurrentHostSpec(); +// verify(hostListProvider).identifyConnection(newConnection); +// assertEquals(expected, actual); +// } +// +// @Test +// void testFillAliasesNonEmptyAliases() throws SQLException { +// final HostSpec oneAlias = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo") +// .build(); +// oneAlias.addAlias(oneAlias.asAlias()); +// +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// +// assertEquals(1, oneAlias.getAliases().size()); +// target.fillAliases(newConnection, oneAlias); +// // Fill aliases should return directly and no additional aliases should be added. +// assertEquals(1, oneAlias.getAliases().size()); +// } +// +// @ParameterizedTest +// @MethodSource("fillAliasesDialects") +// void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { +// final HostSpec empty = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo").build(); +// PluginServiceImpl target = spy( +// new PluginServiceImpl( +// servicesContainer, +// new ExceptionManager(), +// PROPERTIES, +// URL, +// DRIVER_PROTOCOL, +// dialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// sessionStateService)); +// target.hostListProvider = hostListProvider; +// when(target.getDialect()).thenReturn(dialect); +// when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. +// when(resultSet.getString(eq(1))).thenReturn("ip"); +// if (dialect instanceof AuroraPgDialect) { +// when(hostListProvider.identifyConnection(eq(newConnection))) +// .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance").build()); +// } +// +// target.fillAliases(newConnection, empty); +// +// final String[] aliases = empty.getAliases().toArray(new String[] {}); +// assertArrayEquals(expectedInstanceAliases, aliases); +// } +// +// private static Stream fillAliasesDialects() { +// return Stream.of( +// Arguments.of(new AuroraPgDialect(), new String[]{"instance", "foo", "ip"}), +// Arguments.of(new MysqlDialect(), new String[]{"foo", "ip"}) +// ); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java index 797d151be..f0abf31c2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java @@ -1,629 +1,629 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.hostlistprovider; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atMostOnce; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.mysql.cj.exceptions.WrongArgumentException; -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.SQLSyntaxErrorException; -import java.sql.Statement; -import java.sql.Timestamp; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.events.EventPublisher; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.storage.TestStorageServiceImpl; - -class RdsHostListProviderTest { - private StorageService storageService; - private RdsHostListProvider rdsHostListProvider; - - @Mock private Connection mockConnection; - @Mock private Statement mockStatement; - @Mock private ResultSet mockResultSet; - @Mock private FullServicesContainer mockServicesContainer; - @Mock private PluginService mockPluginService; - @Mock private HostListProviderService mockHostListProviderService; - @Mock private EventPublisher mockEventPublisher; - @Mock Dialect mockTopologyAwareDialect; - @Captor private ArgumentCaptor queryCaptor; - - private AutoCloseable closeable; - private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("foo").port(1234).build(); - private final List hosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - storageService = new TestStorageServiceImpl(mockEventPublisher); - when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); - when(mockServicesContainer.getStorageService()).thenReturn(storageService); - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); - when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); - when(mockConnection.createStatement()).thenReturn(mockStatement); - when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); - when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); - when(mockHostListProviderService.getHostSpecBuilder()) - .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); - when(mockHostListProviderService.getCurrentConnection()).thenReturn(mockConnection); - } - - @AfterEach - void tearDown() throws Exception { - RdsHostListProvider.clearAll(); - storageService.clearAll(); - closeable.close(); - } - - private RdsHostListProvider getRdsHostListProvider(String originalUrl) throws SQLException { - RdsHostListProvider provider = new RdsHostListProvider( - new Properties(), - originalUrl, - mockServicesContainer, - "foo", "bar", "baz"); - provider.init(); - return provider; - } - - @Test - void testGetTopology_returnCachedTopology() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("protocol://url/")); - - final List expected = hosts; - storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); - assertEquals(expected, result.hosts); - assertEquals(2, result.hosts.size()); - verify(rdsHostListProvider, never()).queryForTopology(mockConnection); - } - - @Test - void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.isInitialized = true; - - storageService.set(rdsHostListProvider.clusterId, new Topology(hosts)); - - final List newHosts = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); - doReturn(newHosts).when(rdsHostListProvider).queryForTopology(mockConnection); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); - verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertEquals(1, result.hosts.size()); - assertEquals(newHosts, result.hosts); - } - - @Test - void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.clusterId = "cluster-id"; - rdsHostListProvider.isInitialized = true; - - final List expected = hosts; - storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); - - doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); - verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertEquals(2, result.hosts.size()); - assertEquals(expected, result.hosts); - } - - @Test - void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.clear(); - - doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); - verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertNotNull(result.hosts); - assertEquals( - Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), - result.hosts); - } - - @Test - void testQueryForTopology_withDifferentDriverProtocol() throws SQLException { - final List expectedMySQL = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("mysql").port(HostSpec.NO_PORT) - .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); - final List expectedPostgres = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("postgresql").port(HostSpec.NO_PORT) - .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); - when(mockResultSet.next()).thenReturn(true, false); - when(mockResultSet.getBoolean(eq(2))).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("mysql"); - - - rdsHostListProvider = getRdsHostListProvider("mysql://url/"); - - List hosts = rdsHostListProvider.queryForTopology(mockConnection); - assertEquals(expectedMySQL, hosts); - - when(mockResultSet.next()).thenReturn(true, false); - when(mockResultSet.getString(eq(1))).thenReturn("postgresql"); - - rdsHostListProvider = getRdsHostListProvider("postgresql://url/"); - hosts = rdsHostListProvider.queryForTopology(mockConnection); - assertEquals(expectedPostgres, hosts); - } - - @Test - void testQueryForTopology_queryResultsInException() throws SQLException { - rdsHostListProvider = getRdsHostListProvider("protocol://url/"); - when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); - - assertThrows( - SQLException.class, - () -> rdsHostListProvider.queryForTopology(mockConnection)); - } - - @Test - void testGetCachedTopology_returnStoredTopology() throws SQLException { - rdsHostListProvider = getRdsHostListProvider("jdbc:someprotocol://url"); - - final List expected = hosts; - storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); - - final List result = rdsHostListProvider.getStoredTopology(); - assertEquals(expected, result); - } - - @Test - void testTopologyCache_NoSuggestedClusterId() throws SQLException { - RdsHostListProvider.clearAll(); - - RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.domain.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); - - doReturn(topologyClusterA) - .when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - final List topologyProvider1 = provider1.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-b.domain.com/")); - provider2.init(); - assertNull(provider2.getStoredTopology()); - - final List topologyClusterB = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); - doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); - - final List topologyProvider2 = provider2.refresh(mock(Connection.class)); - assertEquals(topologyClusterB, topologyProvider2); - - assertEquals(2, storageService.size(Topology.class)); - } - - @Test - void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { - RdsHostListProvider.clearAll(); - - RdsHostListProvider provider1 = - Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build()); - - doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - final List topologyProvider1 = provider1.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - RdsHostListProvider provider2 = - Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider2.init(); - - assertEquals(provider1.clusterId, provider2.clusterId); - assertTrue(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - - final List topologyProvider2 = provider2.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider2); - - assertEquals(1, storageService.size(Topology.class)); - } - - @Test - void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { - RdsHostListProvider.clearAll(); - - RdsHostListProvider provider1 = - Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build()); - - doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - final List topologyProvider1 = provider1.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - RdsHostListProvider provider2 = - Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); - provider2.init(); - - assertEquals(provider1.clusterId, provider2.clusterId); - assertTrue(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - - final List topologyProvider2 = provider2.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider2); - - assertEquals(1, storageService.size(Topology.class)); - } - - @Test - void testTopologyCache_AcceptSuggestion() throws SQLException { - RdsHostListProvider.clearAll(); - - RdsHostListProvider provider1 = - Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build()); - - doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - List topologyProvider1 = provider1.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - // RdsHostListProvider.logCache(); - - RdsHostListProvider provider2 = - Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider2.init(); - - doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); - - final List topologyProvider2 = provider2.refresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider2); - - assertNotEquals(provider1.clusterId, provider2.clusterId); - assertFalse(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - assertEquals(2, storageService.size(Topology.class)); - assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", - RdsHostListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); - - // RdsHostListProvider.logCache(); - - topologyProvider1 = provider1.forceRefresh(mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - assertEquals(provider1.clusterId, provider2.clusterId); - assertTrue(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - - // RdsHostListProvider.logCache(); - } - - @Test - void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - - when(mockResultSet.next()).thenReturn(false); - assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); - - when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); - assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); - } - - @Test - void testIdentifyConnectionNullTopology() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("?.pattern").build(); - - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); - doReturn(null).when(rdsHostListProvider).refresh(mockConnection); - doReturn(null).when(rdsHostListProvider).forceRefresh(mockConnection); - - assertNull(rdsHostListProvider.identifyConnection(mockConnection)); - } - - @Test - void testIdentifyConnectionHostNotInTopology() throws SQLException { - final List cachedTopology = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build()); - - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); - doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); - doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); - - assertNull(rdsHostListProvider.identifyConnection(mockConnection)); - } - - @Test - void testIdentifyConnectionHostInTopology() throws SQLException { - final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(); - expectedHost.setHostId("instance-a-1"); - final List cachedTopology = Collections.singletonList(expectedHost); - - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); - doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); - doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); - - final HostSpec actual = rdsHostListProvider.identifyConnection(mockConnection); - assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); - assertEquals("instance-a-1", actual.getHostId()); - } - - @Test - void testGetTopology_StaleRecord() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.isInitialized = true; - - final String hostName1 = "hostName1"; - final String hostName2 = "hostName2"; - final Double cpuUtilization = 11.1D; - final Double nodeLag = 0.123D; - final Timestamp firstTimestamp = Timestamp.from(Instant.now()); - final Timestamp secondTimestamp = new Timestamp(firstTimestamp.getTime() + 100); - when(mockResultSet.next()).thenReturn(true, true, false); - when(mockResultSet.getString(1)).thenReturn(hostName1).thenReturn(hostName2); - when(mockResultSet.getBoolean(2)).thenReturn(true).thenReturn(true); - when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization).thenReturn(cpuUtilization); - when(mockResultSet.getDouble(4)).thenReturn(nodeLag).thenReturn(nodeLag); - when(mockResultSet.getTimestamp(5)).thenReturn(firstTimestamp).thenReturn(secondTimestamp); - long weight = Math.round(nodeLag) * 100L + Math.round(cpuUtilization); - final HostSpec expectedWriter = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host(hostName2) - .port(-1) - .role(HostRole.WRITER) - .availability(HostAvailability.AVAILABLE) - .weight(weight) - .lastUpdateTime(secondTimestamp) - .build(); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); - verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertEquals(1, result.hosts.size()); - assertEquals(expectedWriter, result.hosts.get(0)); - } - - @Test - void testGetTopology_InvalidLastUpdatedTimestamp() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.isInitialized = true; - - final String hostName = "hostName"; - final Double cpuUtilization = 11.1D; - final Double nodeLag = 0.123D; - when(mockResultSet.next()).thenReturn(true, false); - when(mockResultSet.getString(1)).thenReturn(hostName); - when(mockResultSet.getBoolean(2)).thenReturn(true); - when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization); - when(mockResultSet.getDouble(4)).thenReturn(nodeLag); - when(mockResultSet.getTimestamp(5)).thenThrow(WrongArgumentException.class); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); - verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); - - final String expectedLastUpdatedTimeStampRounded = Timestamp.from(Instant.now()).toString().substring(0, 16); - assertEquals(1, result.hosts.size()); - assertEquals( - expectedLastUpdatedTimeStampRounded, - result.hosts.get(0).getLastUpdateTime().toString().substring(0, 16)); - } - - @Test - void testGetTopology_returnsLatestWriter() throws SQLException { - rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); - rdsHostListProvider.isInitialized = true; - - HostSpec expectedWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("expectedWriterHost") - .role(HostRole.WRITER) - .lastUpdateTime(Timestamp.valueOf("3000-01-01 00:00:00")) - .build(); - - HostSpec unexpectedWriterHost0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("unexpectedWriterHost0") - .role(HostRole.WRITER) - .lastUpdateTime(Timestamp.valueOf("1000-01-01 00:00:00")) - .build(); - - HostSpec unexpectedWriterHost1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("unexpectedWriterHost1") - .role(HostRole.WRITER) - .lastUpdateTime(Timestamp.valueOf("2000-01-01 00:00:00")) - .build(); - - HostSpec unexpectedWriterHostWithNullLastUpdateTime0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("unexpectedWriterHostWithNullLastUpdateTime0") - .role(HostRole.WRITER) - .lastUpdateTime(null) - .build(); - - HostSpec unexpectedWriterHostWithNullLastUpdateTime1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("unexpectedWriterHostWithNullLastUpdateTime1") - .role(HostRole.WRITER) - .lastUpdateTime(null) - .build(); - - when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); - - when(mockResultSet.getString(1)).thenReturn( - unexpectedWriterHostWithNullLastUpdateTime0.getHost(), - unexpectedWriterHost0.getHost(), - expectedWriterHost.getHost(), - unexpectedWriterHost1.getHost(), - unexpectedWriterHostWithNullLastUpdateTime1.getHost()); - when(mockResultSet.getBoolean(2)).thenReturn(true, true, true, true, true); - when(mockResultSet.getFloat(3)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); - when(mockResultSet.getFloat(4)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); - when(mockResultSet.getTimestamp(5)).thenReturn( - unexpectedWriterHostWithNullLastUpdateTime0.getLastUpdateTime(), - unexpectedWriterHost0.getLastUpdateTime(), - expectedWriterHost.getLastUpdateTime(), - unexpectedWriterHost1.getLastUpdateTime(), - unexpectedWriterHostWithNullLastUpdateTime1.getLastUpdateTime() - ); - - final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); - verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); - - assertEquals(expectedWriterHost.getHost(), result.hosts.get(0).getHost()); - } - - @Test - void testClusterUrlUsedAsDefaultClusterId() throws SQLException { - String readerClusterUrl = "mycluster.cluster-ro-XYZ.us-east-1.rds.amazonaws.com"; - String expectedClusterId = "mycluster.cluster-XYZ.us-east-1.rds.amazonaws.com:1234"; - String connectionString = "jdbc:someprotocol://" + readerClusterUrl + ":1234/test"; - RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider(connectionString)); - assertEquals(expectedClusterId, provider1.getClusterId()); - - List mockTopology = - Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build()); - doReturn(mockTopology).when(provider1).queryForTopology(any(Connection.class)); - provider1.refresh(); - assertEquals(mockTopology, provider1.getStoredTopology()); - verify(provider1, times(1)).queryForTopology(mockConnection); - - RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider(connectionString)); - assertEquals(expectedClusterId, provider2.getClusterId()); - assertEquals(mockTopology, provider2.getStoredTopology()); - verify(provider2, never()).queryForTopology(mockConnection); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.hostlistprovider; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.atMostOnce; +// import static org.mockito.Mockito.doAnswer; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.mock; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import com.mysql.cj.exceptions.WrongArgumentException; +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.SQLSyntaxErrorException; +// import java.sql.Statement; +// import java.sql.Timestamp; +// import java.time.Instant; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.List; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.ArgumentCaptor; +// import org.mockito.Captor; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.events.EventPublisher; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.storage.TestStorageServiceImpl; +// +// class RdsHostListProviderTest { +// private StorageService storageService; +// private RdsHostListProvider rdsHostListProvider; +// +// @Mock private Connection mockConnection; +// @Mock private Statement mockStatement; +// @Mock private ResultSet mockResultSet; +// @Mock private FullServicesContainer mockServicesContainer; +// @Mock private PluginService mockPluginService; +// @Mock private HostListProviderService mockHostListProviderService; +// @Mock private EventPublisher mockEventPublisher; +// @Mock Dialect mockTopologyAwareDialect; +// @Captor private ArgumentCaptor queryCaptor; +// +// private AutoCloseable closeable; +// private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("foo").port(1234).build(); +// private final List hosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// storageService = new TestStorageServiceImpl(mockEventPublisher); +// when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); +// when(mockServicesContainer.getStorageService()).thenReturn(storageService); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); +// when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); +// when(mockConnection.createStatement()).thenReturn(mockStatement); +// when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); +// when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); +// when(mockHostListProviderService.getHostSpecBuilder()) +// .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); +// when(mockHostListProviderService.getCurrentConnection()).thenReturn(mockConnection); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// RdsHostListProvider.clearAll(); +// storageService.clearAll(); +// closeable.close(); +// } +// +// private RdsHostListProvider getRdsHostListProvider(String originalUrl) throws SQLException { +// RdsHostListProvider provider = new RdsHostListProvider( +// new Properties(), +// originalUrl, +// mockServicesContainer, +// "foo", "bar", "baz"); +// provider.init(); +// return provider; +// } +// +// @Test +// void testGetTopology_returnCachedTopology() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("protocol://url/")); +// +// final List expected = hosts; +// storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); +// assertEquals(expected, result.hosts); +// assertEquals(2, result.hosts.size()); +// verify(rdsHostListProvider, never()).queryForTopology(mockConnection); +// } +// +// @Test +// void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.isInitialized = true; +// +// storageService.set(rdsHostListProvider.clusterId, new Topology(hosts)); +// +// final List newHosts = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); +// doReturn(newHosts).when(rdsHostListProvider).queryForTopology(mockConnection); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); +// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertEquals(1, result.hosts.size()); +// assertEquals(newHosts, result.hosts); +// } +// +// @Test +// void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.clusterId = "cluster-id"; +// rdsHostListProvider.isInitialized = true; +// +// final List expected = hosts; +// storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); +// +// doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); +// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertEquals(2, result.hosts.size()); +// assertEquals(expected, result.hosts); +// } +// +// @Test +// void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.clear(); +// +// doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); +// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertNotNull(result.hosts); +// assertEquals( +// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), +// result.hosts); +// } +// +// @Test +// void testQueryForTopology_withDifferentDriverProtocol() throws SQLException { +// final List expectedMySQL = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("mysql").port(HostSpec.NO_PORT) +// .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); +// final List expectedPostgres = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("postgresql").port(HostSpec.NO_PORT) +// .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); +// when(mockResultSet.next()).thenReturn(true, false); +// when(mockResultSet.getBoolean(eq(2))).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("mysql"); +// +// +// rdsHostListProvider = getRdsHostListProvider("mysql://url/"); +// +// List hosts = rdsHostListProvider.queryForTopology(mockConnection); +// assertEquals(expectedMySQL, hosts); +// +// when(mockResultSet.next()).thenReturn(true, false); +// when(mockResultSet.getString(eq(1))).thenReturn("postgresql"); +// +// rdsHostListProvider = getRdsHostListProvider("postgresql://url/"); +// hosts = rdsHostListProvider.queryForTopology(mockConnection); +// assertEquals(expectedPostgres, hosts); +// } +// +// @Test +// void testQueryForTopology_queryResultsInException() throws SQLException { +// rdsHostListProvider = getRdsHostListProvider("protocol://url/"); +// when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); +// +// assertThrows( +// SQLException.class, +// () -> rdsHostListProvider.queryForTopology(mockConnection)); +// } +// +// @Test +// void testGetCachedTopology_returnStoredTopology() throws SQLException { +// rdsHostListProvider = getRdsHostListProvider("jdbc:someprotocol://url"); +// +// final List expected = hosts; +// storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); +// +// final List result = rdsHostListProvider.getStoredTopology(); +// assertEquals(expected, result); +// } +// +// @Test +// void testTopologyCache_NoSuggestedClusterId() throws SQLException { +// RdsHostListProvider.clearAll(); +// +// RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.domain.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); +// +// doReturn(topologyClusterA) +// .when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// final List topologyProvider1 = provider1.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-b.domain.com/")); +// provider2.init(); +// assertNull(provider2.getStoredTopology()); +// +// final List topologyClusterB = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); +// doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); +// +// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); +// assertEquals(topologyClusterB, topologyProvider2); +// +// assertEquals(2, storageService.size(Topology.class)); +// } +// +// @Test +// void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { +// RdsHostListProvider.clearAll(); +// +// RdsHostListProvider provider1 = +// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build()); +// +// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// final List topologyProvider1 = provider1.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// RdsHostListProvider provider2 = +// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider2.init(); +// +// assertEquals(provider1.clusterId, provider2.clusterId); +// assertTrue(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// +// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider2); +// +// assertEquals(1, storageService.size(Topology.class)); +// } +// +// @Test +// void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { +// RdsHostListProvider.clearAll(); +// +// RdsHostListProvider provider1 = +// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build()); +// +// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// final List topologyProvider1 = provider1.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// RdsHostListProvider provider2 = +// Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); +// provider2.init(); +// +// assertEquals(provider1.clusterId, provider2.clusterId); +// assertTrue(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// +// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider2); +// +// assertEquals(1, storageService.size(Topology.class)); +// } +// +// @Test +// void testTopologyCache_AcceptSuggestion() throws SQLException { +// RdsHostListProvider.clearAll(); +// +// RdsHostListProvider provider1 = +// Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build()); +// +// doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// List topologyProvider1 = provider1.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// // RdsHostListProvider.logCache(); +// +// RdsHostListProvider provider2 = +// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider2.init(); +// +// doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); +// +// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider2); +// +// assertNotEquals(provider1.clusterId, provider2.clusterId); +// assertFalse(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// assertEquals(2, storageService.size(Topology.class)); +// assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", +// RdsHostListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); +// +// // RdsHostListProvider.logCache(); +// +// topologyProvider1 = provider1.forceRefresh(mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// assertEquals(provider1.clusterId, provider2.clusterId); +// assertTrue(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// +// // RdsHostListProvider.logCache(); +// } +// +// @Test +// void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// +// when(mockResultSet.next()).thenReturn(false); +// assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); +// +// when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); +// assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); +// } +// +// @Test +// void testIdentifyConnectionNullTopology() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("?.pattern").build(); +// +// when(mockResultSet.next()).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); +// doReturn(null).when(rdsHostListProvider).refresh(mockConnection); +// doReturn(null).when(rdsHostListProvider).forceRefresh(mockConnection); +// +// assertNull(rdsHostListProvider.identifyConnection(mockConnection)); +// } +// +// @Test +// void testIdentifyConnectionHostNotInTopology() throws SQLException { +// final List cachedTopology = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build()); +// +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// when(mockResultSet.next()).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); +// doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); +// doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); +// +// assertNull(rdsHostListProvider.identifyConnection(mockConnection)); +// } +// +// @Test +// void testIdentifyConnectionHostInTopology() throws SQLException { +// final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(); +// expectedHost.setHostId("instance-a-1"); +// final List cachedTopology = Collections.singletonList(expectedHost); +// +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// when(mockResultSet.next()).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); +// doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); +// doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); +// +// final HostSpec actual = rdsHostListProvider.identifyConnection(mockConnection); +// assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); +// assertEquals("instance-a-1", actual.getHostId()); +// } +// +// @Test +// void testGetTopology_StaleRecord() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.isInitialized = true; +// +// final String hostName1 = "hostName1"; +// final String hostName2 = "hostName2"; +// final Double cpuUtilization = 11.1D; +// final Double nodeLag = 0.123D; +// final Timestamp firstTimestamp = Timestamp.from(Instant.now()); +// final Timestamp secondTimestamp = new Timestamp(firstTimestamp.getTime() + 100); +// when(mockResultSet.next()).thenReturn(true, true, false); +// when(mockResultSet.getString(1)).thenReturn(hostName1).thenReturn(hostName2); +// when(mockResultSet.getBoolean(2)).thenReturn(true).thenReturn(true); +// when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization).thenReturn(cpuUtilization); +// when(mockResultSet.getDouble(4)).thenReturn(nodeLag).thenReturn(nodeLag); +// when(mockResultSet.getTimestamp(5)).thenReturn(firstTimestamp).thenReturn(secondTimestamp); +// long weight = Math.round(nodeLag) * 100L + Math.round(cpuUtilization); +// final HostSpec expectedWriter = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host(hostName2) +// .port(-1) +// .role(HostRole.WRITER) +// .availability(HostAvailability.AVAILABLE) +// .weight(weight) +// .lastUpdateTime(secondTimestamp) +// .build(); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); +// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertEquals(1, result.hosts.size()); +// assertEquals(expectedWriter, result.hosts.get(0)); +// } +// +// @Test +// void testGetTopology_InvalidLastUpdatedTimestamp() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.isInitialized = true; +// +// final String hostName = "hostName"; +// final Double cpuUtilization = 11.1D; +// final Double nodeLag = 0.123D; +// when(mockResultSet.next()).thenReturn(true, false); +// when(mockResultSet.getString(1)).thenReturn(hostName); +// when(mockResultSet.getBoolean(2)).thenReturn(true); +// when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization); +// when(mockResultSet.getDouble(4)).thenReturn(nodeLag); +// when(mockResultSet.getTimestamp(5)).thenThrow(WrongArgumentException.class); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); +// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// +// final String expectedLastUpdatedTimeStampRounded = Timestamp.from(Instant.now()).toString().substring(0, 16); +// assertEquals(1, result.hosts.size()); +// assertEquals( +// expectedLastUpdatedTimeStampRounded, +// result.hosts.get(0).getLastUpdateTime().toString().substring(0, 16)); +// } +// +// @Test +// void testGetTopology_returnsLatestWriter() throws SQLException { +// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); +// rdsHostListProvider.isInitialized = true; +// +// HostSpec expectedWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("expectedWriterHost") +// .role(HostRole.WRITER) +// .lastUpdateTime(Timestamp.valueOf("3000-01-01 00:00:00")) +// .build(); +// +// HostSpec unexpectedWriterHost0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("unexpectedWriterHost0") +// .role(HostRole.WRITER) +// .lastUpdateTime(Timestamp.valueOf("1000-01-01 00:00:00")) +// .build(); +// +// HostSpec unexpectedWriterHost1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("unexpectedWriterHost1") +// .role(HostRole.WRITER) +// .lastUpdateTime(Timestamp.valueOf("2000-01-01 00:00:00")) +// .build(); +// +// HostSpec unexpectedWriterHostWithNullLastUpdateTime0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("unexpectedWriterHostWithNullLastUpdateTime0") +// .role(HostRole.WRITER) +// .lastUpdateTime(null) +// .build(); +// +// HostSpec unexpectedWriterHostWithNullLastUpdateTime1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("unexpectedWriterHostWithNullLastUpdateTime1") +// .role(HostRole.WRITER) +// .lastUpdateTime(null) +// .build(); +// +// when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); +// +// when(mockResultSet.getString(1)).thenReturn( +// unexpectedWriterHostWithNullLastUpdateTime0.getHost(), +// unexpectedWriterHost0.getHost(), +// expectedWriterHost.getHost(), +// unexpectedWriterHost1.getHost(), +// unexpectedWriterHostWithNullLastUpdateTime1.getHost()); +// when(mockResultSet.getBoolean(2)).thenReturn(true, true, true, true, true); +// when(mockResultSet.getFloat(3)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); +// when(mockResultSet.getFloat(4)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); +// when(mockResultSet.getTimestamp(5)).thenReturn( +// unexpectedWriterHostWithNullLastUpdateTime0.getLastUpdateTime(), +// unexpectedWriterHost0.getLastUpdateTime(), +// expectedWriterHost.getLastUpdateTime(), +// unexpectedWriterHost1.getLastUpdateTime(), +// unexpectedWriterHostWithNullLastUpdateTime1.getLastUpdateTime() +// ); +// +// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); +// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// +// assertEquals(expectedWriterHost.getHost(), result.hosts.get(0).getHost()); +// } +// +// @Test +// void testClusterUrlUsedAsDefaultClusterId() throws SQLException { +// String readerClusterUrl = "mycluster.cluster-ro-XYZ.us-east-1.rds.amazonaws.com"; +// String expectedClusterId = "mycluster.cluster-XYZ.us-east-1.rds.amazonaws.com:1234"; +// String connectionString = "jdbc:someprotocol://" + readerClusterUrl + ":1234/test"; +// RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider(connectionString)); +// assertEquals(expectedClusterId, provider1.getClusterId()); +// +// List mockTopology = +// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build()); +// doReturn(mockTopology).when(provider1).queryForTopology(any(Connection.class)); +// provider1.refresh(); +// assertEquals(mockTopology, provider1.getStoredTopology()); +// verify(provider1, times(1)).queryForTopology(mockConnection); +// +// RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider(connectionString)); +// assertEquals(expectedClusterId, provider2.getClusterId()); +// assertEquals(mockTopology, provider2.getStoredTopology()); +// verify(provider2, never()).queryForTopology(mockConnection); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java index df6d6ee50..db5e10c62 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java @@ -1,470 +1,470 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.hostlistprovider; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atMostOnce; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.SQLSyntaxErrorException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.events.EventPublisher; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.storage.TestStorageServiceImpl; - -class RdsMultiAzDbClusterListProviderTest { - private StorageService storageService; - private RdsMultiAzDbClusterListProvider rdsMazDbClusterHostListProvider; - - @Mock private Connection mockConnection; - @Mock private Statement mockStatement; - @Mock private ResultSet mockResultSet; - @Mock private FullServicesContainer mockServicesContainer; - @Mock private PluginService mockPluginService; - @Mock private HostListProviderService mockHostListProviderService; - @Mock private EventPublisher mockEventPublisher; - @Mock Dialect mockTopologyAwareDialect; - @Captor private ArgumentCaptor queryCaptor; - - private AutoCloseable closeable; - private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("foo").port(1234).build(); - private final List hosts = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - storageService = new TestStorageServiceImpl(mockEventPublisher); - when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); - when(mockServicesContainer.getStorageService()).thenReturn(storageService); - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); - when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); - when(mockConnection.createStatement()).thenReturn(mockStatement); - when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); - when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); - when(mockHostListProviderService.getHostSpecBuilder()) - .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); - } - - @AfterEach - void tearDown() throws Exception { - RdsMultiAzDbClusterListProvider.clearAll(); - storageService.clearAll(); - closeable.close(); - } - - private RdsMultiAzDbClusterListProvider getRdsMazDbClusterHostListProvider(String originalUrl) throws SQLException { - RdsMultiAzDbClusterListProvider provider = new RdsMultiAzDbClusterListProvider( - new Properties(), - originalUrl, - mockServicesContainer, - "foo", - "bar", - "baz", - "fang", - "li"); - provider.init(); - // provider.clusterId = "cluster-id"; - return provider; - } - - @Test - void testGetTopology_returnCachedTopology() throws SQLException { - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("protocol://url/")); - final List expected = hosts; - storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); - - final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); - assertEquals(expected, result.hosts); - assertEquals(2, result.hosts.size()); - verify(rdsMazDbClusterHostListProvider, never()).queryForTopology(mockConnection); - } - - @Test - void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - rdsMazDbClusterHostListProvider.isInitialized = true; - - storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(hosts)); - - final List newHosts = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); - doReturn(newHosts).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); - - final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); - verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertEquals(1, result.hosts.size()); - assertEquals(newHosts, result.hosts); - } - - @Test - void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - rdsMazDbClusterHostListProvider.clusterId = "cluster-id"; - rdsMazDbClusterHostListProvider.isInitialized = true; - - final List expected = hosts; - storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); - - doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); - - final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); - verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertEquals(2, result.hosts.size()); - assertEquals(expected, result.hosts); - } - - @Test - void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - rdsMazDbClusterHostListProvider.clear(); - - doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); - - final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); - verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); - assertNotNull(result.hosts); - assertEquals( - Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), - result.hosts); - } - - @Test - void testQueryForTopology_queryResultsInException() throws SQLException { - rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("protocol://url/"); - when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); - - assertThrows( - SQLException.class, - () -> rdsMazDbClusterHostListProvider.queryForTopology(mockConnection)); - } - - @Test - void testGetCachedTopology_returnCachedTopology() throws SQLException { - rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url"); - - final List expected = hosts; - storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); - - final List result = rdsMazDbClusterHostListProvider.getStoredTopology(); - assertEquals(expected, result); - } - - @Test - void testTopologyCache_NoSuggestedClusterId() throws SQLException { - RdsMultiAzDbClusterListProvider.clearAll(); - - RdsMultiAzDbClusterListProvider provider1 = - Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-a.domain.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); - - doReturn(topologyClusterA) - .when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - RdsMultiAzDbClusterListProvider provider2 = - Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-b.domain.com/")); - provider2.init(); - assertNull(provider2.getStoredTopology()); - - final List topologyClusterB = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); - doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); - - final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterB, topologyProvider2); - - assertEquals(2, storageService.size(Topology.class)); - } - - @Test - void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { - RdsMultiAzDbClusterListProvider.clearAll(); - - RdsMultiAzDbClusterListProvider provider1 = - Mockito.spy(getRdsMazDbClusterHostListProvider( - "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build()); - - doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - RdsMultiAzDbClusterListProvider provider2 = - Mockito.spy(getRdsMazDbClusterHostListProvider( - "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider2.init(); - - assertEquals(provider1.clusterId, provider2.clusterId); - assertTrue(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - - final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider2); - - assertEquals(1, storageService.size(Topology.class)); - } - - @Test - void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { - RdsMultiAzDbClusterListProvider.clearAll(); - - RdsMultiAzDbClusterListProvider provider1 = - Mockito.spy(getRdsMazDbClusterHostListProvider( - "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build()); - - doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - RdsMultiAzDbClusterListProvider provider2 = - Mockito.spy(getRdsMazDbClusterHostListProvider( - "jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); - provider2.init(); - - assertEquals(provider1.clusterId, provider2.clusterId); - assertTrue(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - - final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider2); - - assertEquals(1, storageService.size(Topology.class)); - } - - @Test - void testTopologyCache_AcceptSuggestion() throws SQLException { - RdsMultiAzDbClusterListProvider.clearAll(); - - RdsMultiAzDbClusterListProvider provider1 = - Mockito.spy(getRdsMazDbClusterHostListProvider( - "jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); - provider1.init(); - final List topologyClusterA = Arrays.asList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build(), - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.READER) - .build()); - - doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); - - assertEquals(0, storageService.size(Topology.class)); - - List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - - // RdsMultiAzDbClusterListProvider.logCache(); - - RdsMultiAzDbClusterListProvider provider2 = - Mockito.spy(getRdsMazDbClusterHostListProvider( - "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); - provider2.init(); - - doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); - - final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider2); - - assertNotEquals(provider1.clusterId, provider2.clusterId); - assertFalse(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - assertEquals(2, storageService.size(Topology.class)); - assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", - RdsMultiAzDbClusterListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); - - // RdsMultiAzDbClusterListProvider.logCache(); - - topologyProvider1 = provider1.forceRefresh(Mockito.mock(Connection.class)); - assertEquals(topologyClusterA, topologyProvider1); - assertEquals(provider1.clusterId, provider2.clusterId); - assertTrue(provider1.isPrimaryClusterId); - assertTrue(provider2.isPrimaryClusterId); - - // RdsMultiAzDbClusterListProvider.logCache(); - } - - @Test - void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - - when(mockResultSet.next()).thenReturn(false); - assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); - - when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); - assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); - } - - @Test - void testIdentifyConnectionNullTopology() throws SQLException { - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - rdsMazDbClusterHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("?.pattern").build(); - - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); - doReturn(null).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); - doReturn(null).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); - - assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); - } - - @Test - void testIdentifyConnectionHostNotInTopology() throws SQLException { - final List cachedTopology = Collections.singletonList( - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build()); - - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); - doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); - doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); - - assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); - } - - @Test - void testIdentifyConnectionHostInTopology() throws SQLException { - final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") - .hostId("instance-a-1") - .port(HostSpec.NO_PORT) - .role(HostRole.WRITER) - .build(); - final List cachedTopology = Collections.singletonList(expectedHost); - - rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); - doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); - doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); - - final HostSpec actual = rdsMazDbClusterHostListProvider.identifyConnection(mockConnection); - assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); - assertEquals("instance-a-1", actual.getHostId()); - } - -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.hostlistprovider; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertFalse; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.junit.jupiter.api.Assertions.assertTrue; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.atMostOnce; +// import static org.mockito.Mockito.doAnswer; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.SQLSyntaxErrorException; +// import java.sql.Statement; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.List; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.ArgumentCaptor; +// import org.mockito.Captor; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.events.EventPublisher; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.storage.TestStorageServiceImpl; +// +// class RdsMultiAzDbClusterListProviderTest { +// private StorageService storageService; +// private RdsMultiAzDbClusterListProvider rdsMazDbClusterHostListProvider; +// +// @Mock private Connection mockConnection; +// @Mock private Statement mockStatement; +// @Mock private ResultSet mockResultSet; +// @Mock private FullServicesContainer mockServicesContainer; +// @Mock private PluginService mockPluginService; +// @Mock private HostListProviderService mockHostListProviderService; +// @Mock private EventPublisher mockEventPublisher; +// @Mock Dialect mockTopologyAwareDialect; +// @Captor private ArgumentCaptor queryCaptor; +// +// private AutoCloseable closeable; +// private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("foo").port(1234).build(); +// private final List hosts = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// storageService = new TestStorageServiceImpl(mockEventPublisher); +// when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); +// when(mockServicesContainer.getStorageService()).thenReturn(storageService); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); +// when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); +// when(mockConnection.createStatement()).thenReturn(mockStatement); +// when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); +// when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); +// when(mockHostListProviderService.getHostSpecBuilder()) +// .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// RdsMultiAzDbClusterListProvider.clearAll(); +// storageService.clearAll(); +// closeable.close(); +// } +// +// private RdsMultiAzDbClusterListProvider getRdsMazDbClusterHostListProvider(String originalUrl) throws SQLException { +// RdsMultiAzDbClusterListProvider provider = new RdsMultiAzDbClusterListProvider( +// new Properties(), +// originalUrl, +// mockServicesContainer, +// "foo", +// "bar", +// "baz", +// "fang", +// "li"); +// provider.init(); +// // provider.clusterId = "cluster-id"; +// return provider; +// } +// +// @Test +// void testGetTopology_returnCachedTopology() throws SQLException { +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("protocol://url/")); +// final List expected = hosts; +// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); +// +// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); +// assertEquals(expected, result.hosts); +// assertEquals(2, result.hosts.size()); +// verify(rdsMazDbClusterHostListProvider, never()).queryForTopology(mockConnection); +// } +// +// @Test +// void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// rdsMazDbClusterHostListProvider.isInitialized = true; +// +// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(hosts)); +// +// final List newHosts = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); +// doReturn(newHosts).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); +// +// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); +// verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertEquals(1, result.hosts.size()); +// assertEquals(newHosts, result.hosts); +// } +// +// @Test +// void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// rdsMazDbClusterHostListProvider.clusterId = "cluster-id"; +// rdsMazDbClusterHostListProvider.isInitialized = true; +// +// final List expected = hosts; +// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); +// +// doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); +// +// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); +// verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertEquals(2, result.hosts.size()); +// assertEquals(expected, result.hosts); +// } +// +// @Test +// void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// rdsMazDbClusterHostListProvider.clear(); +// +// doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); +// +// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); +// verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); +// assertNotNull(result.hosts); +// assertEquals( +// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), +// result.hosts); +// } +// +// @Test +// void testQueryForTopology_queryResultsInException() throws SQLException { +// rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("protocol://url/"); +// when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); +// +// assertThrows( +// SQLException.class, +// () -> rdsMazDbClusterHostListProvider.queryForTopology(mockConnection)); +// } +// +// @Test +// void testGetCachedTopology_returnCachedTopology() throws SQLException { +// rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url"); +// +// final List expected = hosts; +// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); +// +// final List result = rdsMazDbClusterHostListProvider.getStoredTopology(); +// assertEquals(expected, result); +// } +// +// @Test +// void testTopologyCache_NoSuggestedClusterId() throws SQLException { +// RdsMultiAzDbClusterListProvider.clearAll(); +// +// RdsMultiAzDbClusterListProvider provider1 = +// Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-a.domain.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); +// +// doReturn(topologyClusterA) +// .when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// RdsMultiAzDbClusterListProvider provider2 = +// Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-b.domain.com/")); +// provider2.init(); +// assertNull(provider2.getStoredTopology()); +// +// final List topologyClusterB = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); +// doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); +// +// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterB, topologyProvider2); +// +// assertEquals(2, storageService.size(Topology.class)); +// } +// +// @Test +// void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { +// RdsMultiAzDbClusterListProvider.clearAll(); +// +// RdsMultiAzDbClusterListProvider provider1 = +// Mockito.spy(getRdsMazDbClusterHostListProvider( +// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build()); +// +// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// RdsMultiAzDbClusterListProvider provider2 = +// Mockito.spy(getRdsMazDbClusterHostListProvider( +// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider2.init(); +// +// assertEquals(provider1.clusterId, provider2.clusterId); +// assertTrue(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// +// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider2); +// +// assertEquals(1, storageService.size(Topology.class)); +// } +// +// @Test +// void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { +// RdsMultiAzDbClusterListProvider.clearAll(); +// +// RdsMultiAzDbClusterListProvider provider1 = +// Mockito.spy(getRdsMazDbClusterHostListProvider( +// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build()); +// +// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// RdsMultiAzDbClusterListProvider provider2 = +// Mockito.spy(getRdsMazDbClusterHostListProvider( +// "jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); +// provider2.init(); +// +// assertEquals(provider1.clusterId, provider2.clusterId); +// assertTrue(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// +// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider2); +// +// assertEquals(1, storageService.size(Topology.class)); +// } +// +// @Test +// void testTopologyCache_AcceptSuggestion() throws SQLException { +// RdsMultiAzDbClusterListProvider.clearAll(); +// +// RdsMultiAzDbClusterListProvider provider1 = +// Mockito.spy(getRdsMazDbClusterHostListProvider( +// "jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); +// provider1.init(); +// final List topologyClusterA = Arrays.asList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build(), +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.READER) +// .build()); +// +// doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); +// +// assertEquals(0, storageService.size(Topology.class)); +// +// List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// +// // RdsMultiAzDbClusterListProvider.logCache(); +// +// RdsMultiAzDbClusterListProvider provider2 = +// Mockito.spy(getRdsMazDbClusterHostListProvider( +// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); +// provider2.init(); +// +// doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); +// +// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider2); +// +// assertNotEquals(provider1.clusterId, provider2.clusterId); +// assertFalse(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// assertEquals(2, storageService.size(Topology.class)); +// assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", +// RdsMultiAzDbClusterListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); +// +// // RdsMultiAzDbClusterListProvider.logCache(); +// +// topologyProvider1 = provider1.forceRefresh(Mockito.mock(Connection.class)); +// assertEquals(topologyClusterA, topologyProvider1); +// assertEquals(provider1.clusterId, provider2.clusterId); +// assertTrue(provider1.isPrimaryClusterId); +// assertTrue(provider2.isPrimaryClusterId); +// +// // RdsMultiAzDbClusterListProvider.logCache(); +// } +// +// @Test +// void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// +// when(mockResultSet.next()).thenReturn(false); +// assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); +// +// when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); +// assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); +// } +// +// @Test +// void testIdentifyConnectionNullTopology() throws SQLException { +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// rdsMazDbClusterHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("?.pattern").build(); +// +// when(mockResultSet.next()).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); +// doReturn(null).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); +// doReturn(null).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); +// +// assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); +// } +// +// @Test +// void testIdentifyConnectionHostNotInTopology() throws SQLException { +// final List cachedTopology = Collections.singletonList( +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build()); +// +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// when(mockResultSet.next()).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); +// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); +// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); +// +// assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); +// } +// +// @Test +// void testIdentifyConnectionHostInTopology() throws SQLException { +// final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") +// .hostId("instance-a-1") +// .port(HostSpec.NO_PORT) +// .role(HostRole.WRITER) +// .build(); +// final List cachedTopology = Collections.singletonList(expectedHost); +// +// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); +// when(mockResultSet.next()).thenReturn(true); +// when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); +// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); +// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); +// +// final HostSpec actual = rdsMazDbClusterHostListProvider.identifyConnection(mockConnection); +// assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); +// assertEquals("instance-a-1", actual.getHostId()); +// } +// +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java index 9ca4c86dd..6c0d439c0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java @@ -1,156 +1,155 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.mock; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import software.amazon.jdbc.ConnectionPlugin; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.NodeChangeOptions; -import software.amazon.jdbc.OldConnectionSuggestedAction; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; - -public class TestPluginOne implements ConnectionPlugin { - - protected Set subscribedMethods; - protected ArrayList calls; - - TestPluginOne() {} - - public TestPluginOne(ArrayList calls) { - this.calls = calls; - - this.subscribedMethods = new HashSet<>(Arrays.asList("*")); - } - - @Override - public Set getSubscribedMethods() { - return this.subscribedMethods; - } - - @Override - public T execute( - Class resultClass, - Class exceptionClass, - Object methodInvokeOn, - String methodName, - JdbcCallable jdbcMethodFunc, - Object[] jdbcMethodArgs) - throws E { - - this.calls.add(this.getClass().getSimpleName() + ":before"); - - T result; - try { - result = jdbcMethodFunc.call(); - } catch (RuntimeException e) { - throw e; - } catch (Exception e) { - if (exceptionClass.isInstance(e)) { - throw exceptionClass.cast(e); - } - throw new RuntimeException(e); - } - - this.calls.add(this.getClass().getSimpleName() + ":after"); - - return result; - } - - @Override - public Connection connect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable connectFunc) - throws SQLException { - - this.calls.add(this.getClass().getSimpleName() + ":before connect"); - Connection result = connectFunc.call(); - this.calls.add(this.getClass().getSimpleName() + ":after connect"); - return result; - } - - @Override - public Connection forceConnect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable forceConnectFunc) - throws SQLException { - - this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); - Connection result = forceConnectFunc.call(); - this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); - return result; - } - - @Override - public boolean acceptsStrategy(HostRole role, String strategy) { - return false; - } - - @Override - public HostSpec getHostSpecByStrategy(HostRole role, String strategy) { - this.calls.add(this.getClass().getSimpleName() + ":before getHostSpecByStrategy"); - HostSpec result = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("host").port(1234).role(role).build(); - this.calls.add(this.getClass().getSimpleName() + ":after getHostSpecByStrategy"); - return result; - } - - @Override - public HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strategy) { - return getHostSpecByStrategy(role, strategy); - } - - @Override - public void initHostProvider( - String driverProtocol, - String initialUrl, - Properties props, - HostListProviderService hostListProviderService, - JdbcCallable initHostProviderFunc) - throws SQLException { - - // do nothing - } - - @Override - public OldConnectionSuggestedAction notifyConnectionChanged(EnumSet changes) { - return OldConnectionSuggestedAction.NO_OPINION; - } - - @Override - public void notifyNodeListChanged(Map> changes) { - // do nothing - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.mock; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.EnumSet; +// import java.util.HashSet; +// import java.util.List; +// import java.util.Map; +// import java.util.Properties; +// import java.util.Set; +// import software.amazon.jdbc.ConnectionPlugin; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.NodeChangeOptions; +// import software.amazon.jdbc.OldConnectionSuggestedAction; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.util.connection.ConnectionContext; +// +// public class TestPluginOne implements ConnectionPlugin { +// +// protected Set subscribedMethods; +// protected ArrayList calls; +// +// TestPluginOne() {} +// +// public TestPluginOne(ArrayList calls) { +// this.calls = calls; +// +// this.subscribedMethods = new HashSet<>(Arrays.asList("*")); +// } +// +// @Override +// public Set getSubscribedMethods() { +// return this.subscribedMethods; +// } +// +// @Override +// public T execute( +// Class resultClass, +// Class exceptionClass, +// Object methodInvokeOn, +// String methodName, +// JdbcCallable jdbcMethodFunc, +// Object[] jdbcMethodArgs) +// throws E { +// +// this.calls.add(this.getClass().getSimpleName() + ":before"); +// +// T result; +// try { +// result = jdbcMethodFunc.call(); +// } catch (RuntimeException e) { +// throw e; +// } catch (Exception e) { +// if (exceptionClass.isInstance(e)) { +// throw exceptionClass.cast(e); +// } +// throw new RuntimeException(e); +// } +// +// this.calls.add(this.getClass().getSimpleName() + ":after"); +// +// return result; +// } +// +// @Override +// public Connection connect( +// final ConnectionContext connectionContext, +// final HostSpec hostSpec, +// final boolean isInitialConnection, +// final JdbcCallable connectFunc) throws SQLException { +// +// this.calls.add(this.getClass().getSimpleName() + ":before connect"); +// Connection result = connectFunc.call(); +// this.calls.add(this.getClass().getSimpleName() + ":after connect"); +// return result; +// } +// +// @Override +// public Connection forceConnect( +// String driverProtocol, +// HostSpec hostSpec, +// Properties props, +// boolean isInitialConnection, +// JdbcCallable forceConnectFunc) +// throws SQLException { +// +// this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); +// Connection result = forceConnectFunc.call(); +// this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); +// return result; +// } +// +// @Override +// public boolean acceptsStrategy(HostRole role, String strategy) { +// return false; +// } +// +// @Override +// public HostSpec getHostSpecByStrategy(HostRole role, String strategy) { +// this.calls.add(this.getClass().getSimpleName() + ":before getHostSpecByStrategy"); +// HostSpec result = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("host").port(1234).role(role).build(); +// this.calls.add(this.getClass().getSimpleName() + ":after getHostSpecByStrategy"); +// return result; +// } +// +// @Override +// public HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strategy) { +// return getHostSpecByStrategy(role, strategy); +// } +// +// @Override +// public void initHostProvider( +// String driverProtocol, +// String initialUrl, +// Properties props, +// HostListProviderService hostListProviderService, +// JdbcCallable initHostProviderFunc) +// throws SQLException { +// +// // do nothing +// } +// +// @Override +// public OldConnectionSuggestedAction notifyConnectionChanged(EnumSet changes) { +// return OldConnectionSuggestedAction.NO_OPINION; +// } +// +// @Override +// public void notifyNodeListChanged(Map> changes) { +// // do nothing +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java index 6166a788a..f62a93499 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java @@ -1,88 +1,87 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.mock; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Properties; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.JdbcMethod; - -public class TestPluginThree extends TestPluginOne { - - private Connection connection; - - public TestPluginThree(ArrayList calls) { - super(); - this.calls = calls; - - this.subscribedMethods = new HashSet<>(Arrays.asList( - JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.CONNECT.methodName, JdbcMethod.FORCECONNECT.methodName)); - } - - public TestPluginThree(ArrayList calls, Connection connection) { - this(calls); - this.connection = connection; - } - - @Override - public Connection connect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable connectFunc) - throws SQLException { - - this.calls.add(this.getClass().getSimpleName() + ":before connect"); - - if (this.connection != null) { - this.calls.add(this.getClass().getSimpleName() + ":connection"); - return this.connection; - } - - Connection result = connectFunc.call(); - this.calls.add(this.getClass().getSimpleName() + ":after connect"); - - return result; - } - - public Connection forceConnect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable forceConnectFunc) - throws SQLException { - - this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); - - if (this.connection != null) { - this.calls.add(this.getClass().getSimpleName() + ":forced connection"); - return this.connection; - } - - Connection result = forceConnectFunc.call(); - this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); - - return result; - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.mock; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.HashSet; +// import java.util.Properties; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.JdbcMethod; +// import software.amazon.jdbc.util.connection.ConnectionContext; +// +// public class TestPluginThree extends TestPluginOne { +// +// private Connection connection; +// +// public TestPluginThree(ArrayList calls) { +// super(); +// this.calls = calls; +// +// this.subscribedMethods = new HashSet<>(Arrays.asList( +// JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.CONNECT.methodName, JdbcMethod.FORCECONNECT.methodName)); +// } +// +// public TestPluginThree(ArrayList calls, Connection connection) { +// this(calls); +// this.connection = connection; +// } +// +// @Override +// public Connection connect( +// final ConnectionContext connectionContext, +// final HostSpec hostSpec, +// final boolean isInitialConnection, +// final JdbcCallable connectFunc) throws SQLException { +// +// this.calls.add(this.getClass().getSimpleName() + ":before connect"); +// +// if (this.connection != null) { +// this.calls.add(this.getClass().getSimpleName() + ":connection"); +// return this.connection; +// } +// +// Connection result = connectFunc.call(); +// this.calls.add(this.getClass().getSimpleName() + ":after connect"); +// +// return result; +// } +// +// public Connection forceConnect( +// String driverProtocol, +// HostSpec hostSpec, +// Properties props, +// boolean isInitialConnection, +// JdbcCallable forceConnectFunc) +// throws SQLException { +// +// this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); +// +// if (this.connection != null) { +// this.calls.add(this.getClass().getSimpleName() + ":forced connection"); +// return this.connection; +// } +// +// Connection result = forceConnectFunc.call(); +// this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); +// +// return result; +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java index 7e793781e..84bdfd414 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java @@ -1,113 +1,112 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.mock; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Properties; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.JdbcCallable; - -public class TestPluginThrowException extends TestPluginOne { - - protected final Class exceptionClass; - protected final boolean isBefore; - - public TestPluginThrowException( - ArrayList calls, Class exceptionClass, boolean isBefore) { - super(); - this.calls = calls; - this.exceptionClass = exceptionClass; - this.isBefore = isBefore; - - this.subscribedMethods = new HashSet<>(Arrays.asList("*")); - } - - @Override - public T execute( - Class resultClass, - Class exceptionClass, - Object methodInvokeOn, - String methodName, - JdbcCallable jdbcMethodFunc, - Object[] jdbcMethodArgs) - throws E { - - this.calls.add(this.getClass().getSimpleName() + ":before"); - if (this.isBefore) { - try { - throw this.exceptionClass.newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - T result = jdbcMethodFunc.call(); - - this.calls.add(this.getClass().getSimpleName() + ":after"); - //noinspection ConstantConditions - if (!this.isBefore) { - try { - throw this.exceptionClass.newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - return result; - } - - @Override - public Connection connect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable connectFunc) - throws SQLException { - - this.calls.add(this.getClass().getSimpleName() + ":before"); - if (this.isBefore) { - throwException(); - } - - Connection conn = connectFunc.call(); - - this.calls.add(this.getClass().getSimpleName() + ":after"); - if (!this.isBefore) { - throwException(); - } - - return conn; - } - - private void throwException() throws SQLException { - try { - throw this.exceptionClass.newInstance(); - } catch (RuntimeException e) { - throw e; - } catch (Exception e) { - if (e instanceof SQLException) { - throw (SQLException) e; - } - throw new SQLException(e); - } - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.mock; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.HashSet; +// import java.util.Properties; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.util.connection.ConnectionContext; +// +// public class TestPluginThrowException extends TestPluginOne { +// +// protected final Class exceptionClass; +// protected final boolean isBefore; +// +// public TestPluginThrowException( +// ArrayList calls, Class exceptionClass, boolean isBefore) { +// super(); +// this.calls = calls; +// this.exceptionClass = exceptionClass; +// this.isBefore = isBefore; +// +// this.subscribedMethods = new HashSet<>(Arrays.asList("*")); +// } +// +// @Override +// public T execute( +// Class resultClass, +// Class exceptionClass, +// Object methodInvokeOn, +// String methodName, +// JdbcCallable jdbcMethodFunc, +// Object[] jdbcMethodArgs) +// throws E { +// +// this.calls.add(this.getClass().getSimpleName() + ":before"); +// if (this.isBefore) { +// try { +// throw this.exceptionClass.newInstance(); +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// } +// +// T result = jdbcMethodFunc.call(); +// +// this.calls.add(this.getClass().getSimpleName() + ":after"); +// //noinspection ConstantConditions +// if (!this.isBefore) { +// try { +// throw this.exceptionClass.newInstance(); +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// } +// +// return result; +// } +// +// @Override +// public Connection connect( +// final ConnectionContext connectionContext, +// final HostSpec hostSpec, +// final boolean isInitialConnection, +// final JdbcCallable connectFunc) throws SQLException { +// +// this.calls.add(this.getClass().getSimpleName() + ":before"); +// if (this.isBefore) { +// throwException(); +// } +// +// Connection conn = connectFunc.call(); +// +// this.calls.add(this.getClass().getSimpleName() + ":after"); +// if (!this.isBefore) { +// throwException(); +// } +// +// return conn; +// } +// +// private void throwException() throws SQLException { +// try { +// throw this.exceptionClass.newInstance(); +// } catch (RuntimeException e) { +// throw e; +// } catch (Exception e) { +// if (e instanceof SQLException) { +// throw (SQLException) e; +// } +// throw new SQLException(e); +// } +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java index 82af4b437..0a96cd136 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java @@ -1,33 +1,33 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.mock; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import software.amazon.jdbc.JdbcMethod; - -public class TestPluginTwo extends TestPluginOne { - - public TestPluginTwo(ArrayList calls) { - super(); - this.calls = calls; - - this.subscribedMethods = new HashSet<>( - Arrays.asList(JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.BLOB_POSITION.methodName)); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.mock; +// +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.HashSet; +// import software.amazon.jdbc.JdbcMethod; +// +// public class TestPluginTwo extends TestPluginOne { +// +// public TestPluginTwo(ArrayList calls) { +// super(); +// this.calls = calls; +// +// this.subscribedMethods = new HashSet<>( +// Arrays.asList(JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.BLOB_POSITION.methodName)); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index 70d62ae9f..8d00543a1 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -1,245 +1,245 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Collections; -import java.util.HashSet; -import java.util.Properties; -import java.util.Set; -import java.util.stream.Stream; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.plugin.failover.FailoverSQLException; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.RdsUrlType; -import software.amazon.jdbc.util.RdsUtils; - -public class AuroraConnectionTrackerPluginTest { - - public static final Properties EMPTY_PROPERTIES = new Properties(); - @Mock Connection mockConnection; - @Mock Statement mockStatement; - @Mock ResultSet mockResultSet; - @Mock PluginService mockPluginService; - @Mock Dialect mockTopologyAwareDialect; - @Mock RdsUtils mockRdsUtils; - @Mock OpenedConnectionTracker mockTracker; - @Mock JdbcCallable mockConnectionFunction; - @Mock JdbcCallable mockSqlFunction; - @Mock JdbcCallable mockCloseOrAbortFunction; - @Mock TargetDriverDialect mockTargetDriverDialect; - - private static final Object[] SQL_ARGS = {"sql"}; - - private AutoCloseable closeable; - - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - when(mockConnectionFunction.call()).thenReturn(mockConnection); - when(mockSqlFunction.call()).thenReturn(mockResultSet); - when(mockConnection.createStatement()).thenReturn(mockStatement); - when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); - when(mockRdsUtils.getRdsInstanceHostPattern(any(String.class))).thenReturn("?"); - when(mockRdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.getDialect()).thenReturn(mockTopologyAwareDialect); - when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); - when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @ParameterizedTest - @MethodSource("trackNewConnectionsParameters") - public void testTrackNewInstanceConnections( - final String protocol, - final boolean isInitialConnection) throws SQLException { - final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance1") - .build(); - when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); - when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - EMPTY_PROPERTIES, - mockRdsUtils, - mockTracker); - - final Connection actualConnection = plugin.connect( - protocol, - hostSpec, - EMPTY_PROPERTIES, - isInitialConnection, - mockConnectionFunction); - - assertEquals(mockConnection, actualConnection); - verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); - final Set aliases = hostSpec.getAliases(); - assertEquals(0, aliases.size()); - } - - @Test - public void testInvalidateOpenedConnectionsWhenWriterHostNotChange() throws SQLException { - final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); - final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("host") - .role(HostRole.WRITER) - .build(); - final HostSpec newHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-host") - .role(HostRole.WRITER) - .build(); - - // Host list changes during simulated failover - when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList(originalHost)); - doThrow(expectedException).when(mockSqlFunction).call(); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - EMPTY_PROPERTIES, - mockRdsUtils, - mockTracker); - - final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( - ResultSet.class, - SQLException.class, - Statement.class, - "Statement.executeQuery", - mockSqlFunction, - SQL_ARGS - )); - - assertEquals(expectedException, exception); - verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); - verify(mockTracker, never()).invalidateAllConnections(originalHost); - } - - @Test - public void testInvalidateOpenedConnectionsWhenWriterHostChanged() throws SQLException { - final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); - final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") - .build(); - final HostSpec failoverTargetHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2") - .build(); - when(mockPluginService.getAllHosts()) - .thenReturn(Collections.singletonList(originalHost)) - .thenReturn(Collections.singletonList(failoverTargetHost)); - when(mockSqlFunction.call()) - .thenReturn(mockResultSet) - .thenThrow(expectedException); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - EMPTY_PROPERTIES, - mockRdsUtils, - mockTracker); - - plugin.execute( - ResultSet.class, - SQLException.class, - Statement.class, - "Statement.executeQuery", - mockSqlFunction, - SQL_ARGS - ); - - final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( - ResultSet.class, - SQLException.class, - Statement.class, - "Statement.executeQuery", - mockSqlFunction, - SQL_ARGS - )); - assertEquals(expectedException, exception); - verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); - verify(mockTracker).invalidateAllConnections(originalHost); - } - - @ParameterizedTest - @MethodSource("testInvalidateConnectionsOnCloseOrAbortArgs") - public void testInvalidateConnectionsOnCloseOrAbort(final String method) throws SQLException { - final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") - .build(); - when(mockPluginService.getCurrentHostSpec()).thenReturn(originalHost); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - EMPTY_PROPERTIES, - mockRdsUtils, - mockTracker); - - plugin.execute( - Void.class, - SQLException.class, - Connection.class, - method, - mockCloseOrAbortFunction, - SQL_ARGS - ); - - verify(mockTracker).removeConnectionTracking(eq(originalHost), eq(mockConnection)); - } - - static Stream testInvalidateConnectionsOnCloseOrAbortArgs() { - return Stream.of( - Arguments.of(JdbcMethod.CONNECTION_ABORT.methodName), - Arguments.of(JdbcMethod.CONNECTION_CLOSE.methodName) - ); - } - - private static Stream trackNewConnectionsParameters() { - return Stream.of( - Arguments.of("postgresql", true), - Arguments.of("postgresql", false), - Arguments.of("otherProtocol", true), - Arguments.of("otherProtocol", false) - ); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doThrow; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.Collections; +// import java.util.HashSet; +// import java.util.Properties; +// import java.util.Set; +// import java.util.stream.Stream; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.Arguments; +// import org.junit.jupiter.params.provider.MethodSource; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.JdbcMethod; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.plugin.failover.FailoverSQLException; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.RdsUrlType; +// import software.amazon.jdbc.util.RdsUtils; +// +// public class AuroraConnectionTrackerPluginTest { +// +// public static final Properties EMPTY_PROPERTIES = new Properties(); +// @Mock Connection mockConnection; +// @Mock Statement mockStatement; +// @Mock ResultSet mockResultSet; +// @Mock PluginService mockPluginService; +// @Mock Dialect mockTopologyAwareDialect; +// @Mock RdsUtils mockRdsUtils; +// @Mock OpenedConnectionTracker mockTracker; +// @Mock JdbcCallable mockConnectionFunction; +// @Mock JdbcCallable mockSqlFunction; +// @Mock JdbcCallable mockCloseOrAbortFunction; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// +// private static final Object[] SQL_ARGS = {"sql"}; +// +// private AutoCloseable closeable; +// +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// when(mockConnectionFunction.call()).thenReturn(mockConnection); +// when(mockSqlFunction.call()).thenReturn(mockResultSet); +// when(mockConnection.createStatement()).thenReturn(mockStatement); +// when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); +// when(mockRdsUtils.getRdsInstanceHostPattern(any(String.class))).thenReturn("?"); +// when(mockRdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); +// when(mockPluginService.getDialect()).thenReturn(mockTopologyAwareDialect); +// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); +// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @ParameterizedTest +// @MethodSource("trackNewConnectionsParameters") +// public void testTrackNewInstanceConnections( +// final String protocol, +// final boolean isInitialConnection) throws SQLException { +// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance1") +// .build(); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); +// when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true); +// +// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( +// mockPluginService, +// EMPTY_PROPERTIES, +// mockRdsUtils, +// mockTracker); +// +// final Connection actualConnection = plugin.connect( +// protocol, +// hostSpec, +// EMPTY_PROPERTIES, +// isInitialConnection, +// mockConnectionFunction); +// +// assertEquals(mockConnection, actualConnection); +// verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); +// final Set aliases = hostSpec.getAliases(); +// assertEquals(0, aliases.size()); +// } +// +// @Test +// public void testInvalidateOpenedConnectionsWhenWriterHostNotChange() throws SQLException { +// final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); +// final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("host") +// .role(HostRole.WRITER) +// .build(); +// final HostSpec newHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("new-host") +// .role(HostRole.WRITER) +// .build(); +// +// // Host list changes during simulated failover +// when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList(originalHost)); +// doThrow(expectedException).when(mockSqlFunction).call(); +// +// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( +// mockPluginService, +// EMPTY_PROPERTIES, +// mockRdsUtils, +// mockTracker); +// +// final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( +// ResultSet.class, +// SQLException.class, +// Statement.class, +// "Statement.executeQuery", +// mockSqlFunction, +// SQL_ARGS +// )); +// +// assertEquals(expectedException, exception); +// verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); +// verify(mockTracker, never()).invalidateAllConnections(originalHost); +// } +// +// @Test +// public void testInvalidateOpenedConnectionsWhenWriterHostChanged() throws SQLException { +// final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); +// final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") +// .build(); +// final HostSpec failoverTargetHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2") +// .build(); +// when(mockPluginService.getAllHosts()) +// .thenReturn(Collections.singletonList(originalHost)) +// .thenReturn(Collections.singletonList(failoverTargetHost)); +// when(mockSqlFunction.call()) +// .thenReturn(mockResultSet) +// .thenThrow(expectedException); +// +// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( +// mockPluginService, +// EMPTY_PROPERTIES, +// mockRdsUtils, +// mockTracker); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// Statement.class, +// "Statement.executeQuery", +// mockSqlFunction, +// SQL_ARGS +// ); +// +// final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( +// ResultSet.class, +// SQLException.class, +// Statement.class, +// "Statement.executeQuery", +// mockSqlFunction, +// SQL_ARGS +// )); +// assertEquals(expectedException, exception); +// verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); +// verify(mockTracker).invalidateAllConnections(originalHost); +// } +// +// @ParameterizedTest +// @MethodSource("testInvalidateConnectionsOnCloseOrAbortArgs") +// public void testInvalidateConnectionsOnCloseOrAbort(final String method) throws SQLException { +// final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") +// .build(); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(originalHost); +// +// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( +// mockPluginService, +// EMPTY_PROPERTIES, +// mockRdsUtils, +// mockTracker); +// +// plugin.execute( +// Void.class, +// SQLException.class, +// Connection.class, +// method, +// mockCloseOrAbortFunction, +// SQL_ARGS +// ); +// +// verify(mockTracker).removeConnectionTracking(eq(originalHost), eq(mockConnection)); +// } +// +// static Stream testInvalidateConnectionsOnCloseOrAbortArgs() { +// return Stream.of( +// Arguments.of(JdbcMethod.CONNECTION_ABORT.methodName), +// Arguments.of(JdbcMethod.CONNECTION_CLOSE.methodName) +// ); +// } +// +// private static Stream trackNewConnectionsParameters() { +// return Stream.of( +// Arguments.of("postgresql", true), +// Arguments.of("postgresql", false), +// Arguments.of("otherProtocol", true), +// Arguments.of("otherProtocol", false) +// ); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java index 22a8339c1..e7517c933 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java @@ -1,516 +1,516 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.REGION_PROPERTY; -import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.SECRET_ID_PROPERTY; - -import com.mysql.cj.exceptions.CJException; -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Properties; -import java.util.stream.Stream; -import org.jetbrains.annotations.NotNull; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.postgresql.util.PSQLException; -import org.postgresql.util.PSQLState; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; -import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; -import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; -import software.amazon.awssdk.services.secretsmanager.model.SecretsManagerException; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginServiceImpl; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.DialectManager; -import software.amazon.jdbc.exceptions.ExceptionHandler; -import software.amazon.jdbc.exceptions.ExceptionManager; -import software.amazon.jdbc.exceptions.MySQLExceptionHandler; -import software.amazon.jdbc.exceptions.PgExceptionHandler; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.states.SessionStateService; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.Pair; -import software.amazon.jdbc.util.telemetry.GaugeCallable; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; - -@SuppressWarnings("resource") -public class AwsSecretsManagerConnectionPluginTest { - - private static final String TEST_PG_PROTOCOL = "jdbc:aws-wrapper:postgresql:"; - private static final String TEST_MYSQL_PROTOCOL = "jdbc:aws-wrapper:mysql:"; - private static final String TEST_REGION = "us-east-2"; - private static final String TEST_SECRET_ID = "secretId"; - private static final String TEST_USERNAME = "testUser"; - private static final String TEST_PASSWORD = "testPassword"; - private static final String VALID_SECRET_STRING = - "{\"username\": \"" + TEST_USERNAME + "\", \"password\": \"" + TEST_PASSWORD + "\"}"; - private static final String INVALID_SECRET_STRING = "{username: invalid, password: invalid}"; - private static final String TEST_HOST = "test-domain"; - private static final String TEST_SQL_ERROR = "SQL exception error message"; - private static final String UNHANDLED_ERROR_CODE = "HY000"; - private static final int TEST_PORT = 5432; - private static final Pair SECRET_CACHE_KEY = Pair.create(TEST_SECRET_ID, TEST_REGION); - private static final AwsSecretsManagerConnectionPlugin.Secret TEST_SECRET = - new AwsSecretsManagerConnectionPlugin.Secret("testUser", "testPassword"); - private static final HostSpec TEST_HOSTSPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host(TEST_HOST).port(TEST_PORT).build(); - private static final GetSecretValueResponse VALID_GET_SECRET_VALUE_RESPONSE = - GetSecretValueResponse.builder().secretString(VALID_SECRET_STRING).build(); - private static final GetSecretValueResponse INVALID_GET_SECRET_VALUE_RESPONSE = - GetSecretValueResponse.builder().secretString(INVALID_SECRET_STRING).build(); - private static final Properties TEST_PROPS = new Properties(); - private AwsSecretsManagerConnectionPlugin plugin; - - private AutoCloseable closeable; - - @Mock FullServicesContainer mockServicesContainer; - @Mock SecretsManagerClient mockSecretsManagerClient; - @Mock GetSecretValueRequest mockGetValueRequest; - @Mock JdbcCallable connectFunc; - @Mock PluginServiceImpl mockService; - @Mock ConnectionPluginManager mockConnectionPluginManager; - @Mock Dialect mockTopologyAwareDialect; - @Mock DialectManager mockDialectManager; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryGauge mockTelemetryGauge; - @Mock TargetDriverDialect mockTargetDriverDialect; - ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); - - @Mock SessionStateService mockSessionStateService; - - @BeforeEach - public void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - - REGION_PROPERTY.set(TEST_PROPS, TEST_REGION); - SECRET_ID_PROPERTY.set(TEST_PROPS, TEST_SECRET_ID); - - when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) - .thenReturn(mockTopologyAwareDialect); - - when(mockServicesContainer.getConnectionPluginManager()).thenReturn(mockConnectionPluginManager); - when(mockService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - // noinspection unchecked - when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - - this.plugin = new AwsSecretsManagerConnectionPlugin( - mockService, - TEST_PROPS, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest); - - when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) - .thenReturn(mockTopologyAwareDialect); - - when(mockService.getHostSpecBuilder()).thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - AwsSecretsManagerCacheHolder.clearCache(); - TEST_PROPS.clear(); - } - - /** - * The plugin will successfully open a connection with a cached secret. - */ - @Test - public void testConnectWithCachedSecrets() throws SQLException { - // Add initial cached secret to be used for a connection. - AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); - - this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); - - assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); - verify(this.connectFunc).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - /** - * The plugin will attempt to open a connection with an empty secret cache. The plugin will fetch the secret from the - * AWS Secrets Manager. - */ - @Test - public void testConnectWithNewSecrets() throws SQLException { - when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) - .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - - this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); - - assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); - verify(connectFunc).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - @ParameterizedTest - @MethodSource("missingArguments") - public void testMissingRequiredParameters(final Properties properties) { - assertThrows(RuntimeException.class, () -> new AwsSecretsManagerConnectionPlugin( - mockService, - properties, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest)); - } - - /** - * The plugin will attempt to open a connection with a cached secret, but it will fail with a generic SQL exception. - * In this case, the plugin will rethrow the error back to the user. - */ - @Test - public void testFailedInitialConnectionWithUnhandledError() throws SQLException { - AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); - final SQLException failedFirstConnectionGenericException = new SQLException(TEST_SQL_ERROR, UNHANDLED_ERROR_CODE); - doThrow(failedFirstConnectionGenericException).when(connectFunc).call(); - - final SQLException connectionFailedException = assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_PG_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals(TEST_SQL_ERROR, connectionFailedException.getMessage()); - verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); - verify(connectFunc).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - /** - * The plugin will attempt to open a connection with a cached secret, but it will fail with an access error. In this - * case, the plugin will fetch the secret and will retry the connection. - */ - @ParameterizedTest - @MethodSource("provideExceptionCodeForDifferentDrivers") - public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( - String accessError, - String protocol, - ExceptionHandler exceptionHandler) throws SQLException { - this.plugin = new AwsSecretsManagerConnectionPlugin( - getPluginService(protocol), - TEST_PROPS, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest); - - // Fail the initial connection attempt with cached secret. - // Second attempt should be successful. - AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); - final SQLException failedFirstConnectionAccessException = new SQLException(TEST_SQL_ERROR, - accessError); - doThrow(failedFirstConnectionAccessException).when(connectFunc).call(); - when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) - .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - - when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(exceptionHandler); - - assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_PG_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); - verify(connectFunc, times(2)).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - private @NotNull PluginServiceImpl getPluginService(String protocol) throws SQLException { - return new PluginServiceImpl( - mockServicesContainer, - new ExceptionManager(), - TEST_PROPS, - "url", - protocol, - mockDialectManager, - mockTargetDriverDialect, - configurationProfile, - mockSessionStateService); - } - - /** - * The plugin will attempt to open a connection after fetching a secret, but it will fail because the returned secret - * could not be parsed. - */ - @Test - public void testFailedToReadSecrets() throws SQLException { - when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) - .thenReturn(INVALID_GET_SECRET_VALUE_RESPONSE); - - final SQLException readSecretsFailedException = - assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_PG_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals( - readSecretsFailedException.getMessage(), - Messages.get( - "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); - assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); - verify(this.connectFunc, never()).call(); - } - - /** - * The plugin will attempt to open a connection after fetching a secret, but it will fail because an exception was - * thrown by the AWS Secrets Manager. - */ - @Test - public void testFailedToGetSecrets() throws SQLException { - doThrow(SecretsManagerException.class).when(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); - - final SQLException getSecretsFailedException = - assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_PG_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals( - getSecretsFailedException.getMessage(), - Messages.get( - "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); - assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); - verify(this.connectFunc, never()).call(); - } - - @ParameterizedTest - @ValueSource(strings = {"28000", "28P01"}) - public void testFailedInitialConnectionWithWrappedGenericError(final String accessError) throws SQLException { - this.plugin = new AwsSecretsManagerConnectionPlugin( - getPluginService(TEST_PG_PROTOCOL), - TEST_PROPS, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest); - - // Fail the initial connection attempt with a wrapped exception. - // Second attempt should be successful. - final SQLException targetException = new SQLException(TEST_SQL_ERROR, accessError); - final SQLException wrappedException = new SQLException(targetException); - doThrow(wrappedException).when(connectFunc).call(); - when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) - .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - - when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); - - assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_PG_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(connectFunc).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - @Test - public void testConnectWithWrappedMySQLException() throws SQLException { - this.plugin = new AwsSecretsManagerConnectionPlugin( - getPluginService(TEST_MYSQL_PROTOCOL), - TEST_PROPS, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest); - - final CJException targetException = new CJException("28000"); - final SQLException wrappedException = new SQLException(targetException); - - doThrow(wrappedException).when(connectFunc).call(); - when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) - .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - - when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); - - assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_MYSQL_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(connectFunc).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - @Test - public void testConnectWithWrappedPostgreSQLException() throws SQLException { - this.plugin = new AwsSecretsManagerConnectionPlugin( - getPluginService(TEST_PG_PROTOCOL), - TEST_PROPS, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest); - - final PSQLException targetException = new PSQLException("login error", PSQLState.INVALID_PASSWORD, null); - final SQLException wrappedException = new SQLException(targetException); - - doThrow(wrappedException).when(connectFunc).call(); - when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) - .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - - when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); - - assertThrows( - SQLException.class, - () -> this.plugin.connect( - TEST_PG_PROTOCOL, - TEST_HOSTSPEC, - TEST_PROPS, - true, - this.connectFunc)); - - assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); - verify(connectFunc).call(); - assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); - assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); - } - - @ParameterizedTest - @MethodSource("arnArguments") - public void testConnectViaARN(final String arn, final Region expectedRegionParsedFromARN) - throws SQLException { - final Properties props = new Properties(); - - SECRET_ID_PROPERTY.set(props, arn); - - this.plugin = spy(new AwsSecretsManagerConnectionPlugin( - new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), - props, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest)); - - final Pair secret = this.plugin.secretKey; - assertEquals(expectedRegionParsedFromARN, Region.of(secret.getValue2())); - } - - @ParameterizedTest - @MethodSource("arnArguments") - public void testConnectionWithRegionParameterAndARN(final String arn, final Region regionParsedFromARN) - throws SQLException { - final Region expectedRegion = Region.US_ISO_EAST_1; - - final Properties props = new Properties(); - SECRET_ID_PROPERTY.set(props, arn); - REGION_PROPERTY.set(props, expectedRegion.toString()); - - this.plugin = spy(new AwsSecretsManagerConnectionPlugin( - new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), - props, - (host, r) -> mockSecretsManagerClient, - (id) -> mockGetValueRequest)); - - final Pair secret = this.plugin.secretKey; - // The region specified in `secretsManagerRegion` should override the region parsed from ARN. - assertNotEquals(regionParsedFromARN, Region.of(secret.getValue2())); - assertEquals(expectedRegion, Region.of(secret.getValue2())); - } - - private static Stream provideExceptionCodeForDifferentDrivers() { - return Stream.of( - Arguments.of("28000", TEST_MYSQL_PROTOCOL, new MySQLExceptionHandler()), - Arguments.of("28P01", TEST_PG_PROTOCOL, new PgExceptionHandler()) - ); - } - - private static Stream arnArguments() { - return Stream.of( - Arguments.of("arn:aws:secretsmanager:us-east-2:123456789012:secret:foo", Region.US_EAST_2), - Arguments.of("arn:aws:secretsmanager:us-west-1:123456789012:secret:boo", Region.US_WEST_1), - Arguments.of( - "arn:aws:secretsmanager:us-east-2:123456789012:secret:rds!cluster-bar-foo", - Region.US_EAST_2) - ); - } - - private static Stream missingArguments() { - final Properties missingId = new Properties(); - REGION_PROPERTY.set(missingId, TEST_REGION); - - final Properties missingRegion = new Properties(); - SECRET_ID_PROPERTY.set(missingRegion, TEST_SECRET_ID); - - return Stream.of( - Arguments.of(missingId), - Arguments.of(missingRegion) - ); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doThrow; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.REGION_PROPERTY; +// import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.SECRET_ID_PROPERTY; +// +// import com.mysql.cj.exceptions.CJException; +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Properties; +// import java.util.stream.Stream; +// import org.jetbrains.annotations.NotNull; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.Arguments; +// import org.junit.jupiter.params.provider.MethodSource; +// import org.junit.jupiter.params.provider.ValueSource; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import org.postgresql.util.PSQLException; +// import org.postgresql.util.PSQLState; +// import software.amazon.awssdk.regions.Region; +// import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +// import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +// import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; +// import software.amazon.awssdk.services.secretsmanager.model.SecretsManagerException; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginServiceImpl; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.dialect.DialectManager; +// import software.amazon.jdbc.exceptions.ExceptionHandler; +// import software.amazon.jdbc.exceptions.ExceptionManager; +// import software.amazon.jdbc.exceptions.MySQLExceptionHandler; +// import software.amazon.jdbc.exceptions.PgExceptionHandler; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.profile.ConfigurationProfile; +// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +// import software.amazon.jdbc.states.SessionStateService; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.Messages; +// import software.amazon.jdbc.util.Pair; +// import software.amazon.jdbc.util.telemetry.GaugeCallable; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryGauge; +// +// @SuppressWarnings("resource") +// public class AwsSecretsManagerConnectionPluginTest { +// +// private static final String TEST_PG_PROTOCOL = "jdbc:aws-wrapper:postgresql:"; +// private static final String TEST_MYSQL_PROTOCOL = "jdbc:aws-wrapper:mysql:"; +// private static final String TEST_REGION = "us-east-2"; +// private static final String TEST_SECRET_ID = "secretId"; +// private static final String TEST_USERNAME = "testUser"; +// private static final String TEST_PASSWORD = "testPassword"; +// private static final String VALID_SECRET_STRING = +// "{\"username\": \"" + TEST_USERNAME + "\", \"password\": \"" + TEST_PASSWORD + "\"}"; +// private static final String INVALID_SECRET_STRING = "{username: invalid, password: invalid}"; +// private static final String TEST_HOST = "test-domain"; +// private static final String TEST_SQL_ERROR = "SQL exception error message"; +// private static final String UNHANDLED_ERROR_CODE = "HY000"; +// private static final int TEST_PORT = 5432; +// private static final Pair SECRET_CACHE_KEY = Pair.create(TEST_SECRET_ID, TEST_REGION); +// private static final AwsSecretsManagerConnectionPlugin.Secret TEST_SECRET = +// new AwsSecretsManagerConnectionPlugin.Secret("testUser", "testPassword"); +// private static final HostSpec TEST_HOSTSPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host(TEST_HOST).port(TEST_PORT).build(); +// private static final GetSecretValueResponse VALID_GET_SECRET_VALUE_RESPONSE = +// GetSecretValueResponse.builder().secretString(VALID_SECRET_STRING).build(); +// private static final GetSecretValueResponse INVALID_GET_SECRET_VALUE_RESPONSE = +// GetSecretValueResponse.builder().secretString(INVALID_SECRET_STRING).build(); +// private static final Properties TEST_PROPS = new Properties(); +// private AwsSecretsManagerConnectionPlugin plugin; +// +// private AutoCloseable closeable; +// +// @Mock FullServicesContainer mockServicesContainer; +// @Mock SecretsManagerClient mockSecretsManagerClient; +// @Mock GetSecretValueRequest mockGetValueRequest; +// @Mock JdbcCallable connectFunc; +// @Mock PluginServiceImpl mockService; +// @Mock ConnectionPluginManager mockConnectionPluginManager; +// @Mock Dialect mockTopologyAwareDialect; +// @Mock DialectManager mockDialectManager; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryGauge mockTelemetryGauge; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); +// +// @Mock SessionStateService mockSessionStateService; +// +// @BeforeEach +// public void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// +// REGION_PROPERTY.set(TEST_PROPS, TEST_REGION); +// SECRET_ID_PROPERTY.set(TEST_PROPS, TEST_SECRET_ID); +// +// when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) +// .thenReturn(mockTopologyAwareDialect); +// +// when(mockServicesContainer.getConnectionPluginManager()).thenReturn(mockConnectionPluginManager); +// when(mockService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// // noinspection unchecked +// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); +// +// this.plugin = new AwsSecretsManagerConnectionPlugin( +// mockService, +// TEST_PROPS, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest); +// +// when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) +// .thenReturn(mockTopologyAwareDialect); +// +// when(mockService.getHostSpecBuilder()).thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// AwsSecretsManagerCacheHolder.clearCache(); +// TEST_PROPS.clear(); +// } +// +// /** +// * The plugin will successfully open a connection with a cached secret. +// */ +// @Test +// public void testConnectWithCachedSecrets() throws SQLException { +// // Add initial cached secret to be used for a connection. +// AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); +// +// this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); +// +// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); +// verify(this.connectFunc).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// /** +// * The plugin will attempt to open a connection with an empty secret cache. The plugin will fetch the secret from the +// * AWS Secrets Manager. +// */ +// @Test +// public void testConnectWithNewSecrets() throws SQLException { +// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) +// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); +// +// this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); +// +// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); +// verify(connectFunc).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// @ParameterizedTest +// @MethodSource("missingArguments") +// public void testMissingRequiredParameters(final Properties properties) { +// assertThrows(RuntimeException.class, () -> new AwsSecretsManagerConnectionPlugin( +// mockService, +// properties, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest)); +// } +// +// /** +// * The plugin will attempt to open a connection with a cached secret, but it will fail with a generic SQL exception. +// * In this case, the plugin will rethrow the error back to the user. +// */ +// @Test +// public void testFailedInitialConnectionWithUnhandledError() throws SQLException { +// AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); +// final SQLException failedFirstConnectionGenericException = new SQLException(TEST_SQL_ERROR, UNHANDLED_ERROR_CODE); +// doThrow(failedFirstConnectionGenericException).when(connectFunc).call(); +// +// final SQLException connectionFailedException = assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_PG_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals(TEST_SQL_ERROR, connectionFailedException.getMessage()); +// verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); +// verify(connectFunc).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// /** +// * The plugin will attempt to open a connection with a cached secret, but it will fail with an access error. In this +// * case, the plugin will fetch the secret and will retry the connection. +// */ +// @ParameterizedTest +// @MethodSource("provideExceptionCodeForDifferentDrivers") +// public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( +// String accessError, +// String protocol, +// ExceptionHandler exceptionHandler) throws SQLException { +// this.plugin = new AwsSecretsManagerConnectionPlugin( +// getPluginService(protocol), +// TEST_PROPS, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest); +// +// // Fail the initial connection attempt with cached secret. +// // Second attempt should be successful. +// AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); +// final SQLException failedFirstConnectionAccessException = new SQLException(TEST_SQL_ERROR, +// accessError); +// doThrow(failedFirstConnectionAccessException).when(connectFunc).call(); +// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) +// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); +// +// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(exceptionHandler); +// +// assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_PG_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); +// verify(connectFunc, times(2)).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// private @NotNull PluginServiceImpl getPluginService(String protocol) throws SQLException { +// return new PluginServiceImpl( +// mockServicesContainer, +// new ExceptionManager(), +// TEST_PROPS, +// "url", +// protocol, +// mockDialectManager, +// mockTargetDriverDialect, +// configurationProfile, +// mockSessionStateService); +// } +// +// /** +// * The plugin will attempt to open a connection after fetching a secret, but it will fail because the returned secret +// * could not be parsed. +// */ +// @Test +// public void testFailedToReadSecrets() throws SQLException { +// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) +// .thenReturn(INVALID_GET_SECRET_VALUE_RESPONSE); +// +// final SQLException readSecretsFailedException = +// assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_PG_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals( +// readSecretsFailedException.getMessage(), +// Messages.get( +// "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); +// assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); +// verify(this.connectFunc, never()).call(); +// } +// +// /** +// * The plugin will attempt to open a connection after fetching a secret, but it will fail because an exception was +// * thrown by the AWS Secrets Manager. +// */ +// @Test +// public void testFailedToGetSecrets() throws SQLException { +// doThrow(SecretsManagerException.class).when(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); +// +// final SQLException getSecretsFailedException = +// assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_PG_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals( +// getSecretsFailedException.getMessage(), +// Messages.get( +// "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); +// assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); +// verify(this.connectFunc, never()).call(); +// } +// +// @ParameterizedTest +// @ValueSource(strings = {"28000", "28P01"}) +// public void testFailedInitialConnectionWithWrappedGenericError(final String accessError) throws SQLException { +// this.plugin = new AwsSecretsManagerConnectionPlugin( +// getPluginService(TEST_PG_PROTOCOL), +// TEST_PROPS, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest); +// +// // Fail the initial connection attempt with a wrapped exception. +// // Second attempt should be successful. +// final SQLException targetException = new SQLException(TEST_SQL_ERROR, accessError); +// final SQLException wrappedException = new SQLException(targetException); +// doThrow(wrappedException).when(connectFunc).call(); +// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) +// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); +// +// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); +// +// assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_PG_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(connectFunc).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// @Test +// public void testConnectWithWrappedMySQLException() throws SQLException { +// this.plugin = new AwsSecretsManagerConnectionPlugin( +// getPluginService(TEST_MYSQL_PROTOCOL), +// TEST_PROPS, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest); +// +// final CJException targetException = new CJException("28000"); +// final SQLException wrappedException = new SQLException(targetException); +// +// doThrow(wrappedException).when(connectFunc).call(); +// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) +// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); +// +// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); +// +// assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_MYSQL_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(connectFunc).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// @Test +// public void testConnectWithWrappedPostgreSQLException() throws SQLException { +// this.plugin = new AwsSecretsManagerConnectionPlugin( +// getPluginService(TEST_PG_PROTOCOL), +// TEST_PROPS, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest); +// +// final PSQLException targetException = new PSQLException("login error", PSQLState.INVALID_PASSWORD, null); +// final SQLException wrappedException = new SQLException(targetException); +// +// doThrow(wrappedException).when(connectFunc).call(); +// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) +// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); +// +// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); +// +// assertThrows( +// SQLException.class, +// () -> this.plugin.connect( +// TEST_PG_PROTOCOL, +// TEST_HOSTSPEC, +// TEST_PROPS, +// true, +// this.connectFunc)); +// +// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); +// verify(connectFunc).call(); +// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); +// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); +// } +// +// @ParameterizedTest +// @MethodSource("arnArguments") +// public void testConnectViaARN(final String arn, final Region expectedRegionParsedFromARN) +// throws SQLException { +// final Properties props = new Properties(); +// +// SECRET_ID_PROPERTY.set(props, arn); +// +// this.plugin = spy(new AwsSecretsManagerConnectionPlugin( +// new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), +// props, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest)); +// +// final Pair secret = this.plugin.secretKey; +// assertEquals(expectedRegionParsedFromARN, Region.of(secret.getValue2())); +// } +// +// @ParameterizedTest +// @MethodSource("arnArguments") +// public void testConnectionWithRegionParameterAndARN(final String arn, final Region regionParsedFromARN) +// throws SQLException { +// final Region expectedRegion = Region.US_ISO_EAST_1; +// +// final Properties props = new Properties(); +// SECRET_ID_PROPERTY.set(props, arn); +// REGION_PROPERTY.set(props, expectedRegion.toString()); +// +// this.plugin = spy(new AwsSecretsManagerConnectionPlugin( +// new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), +// props, +// (host, r) -> mockSecretsManagerClient, +// (id) -> mockGetValueRequest)); +// +// final Pair secret = this.plugin.secretKey; +// // The region specified in `secretsManagerRegion` should override the region parsed from ARN. +// assertNotEquals(regionParsedFromARN, Region.of(secret.getValue2())); +// assertEquals(expectedRegion, Region.of(secret.getValue2())); +// } +// +// private static Stream provideExceptionCodeForDifferentDrivers() { +// return Stream.of( +// Arguments.of("28000", TEST_MYSQL_PROTOCOL, new MySQLExceptionHandler()), +// Arguments.of("28P01", TEST_PG_PROTOCOL, new PgExceptionHandler()) +// ); +// } +// +// private static Stream arnArguments() { +// return Stream.of( +// Arguments.of("arn:aws:secretsmanager:us-east-2:123456789012:secret:foo", Region.US_EAST_2), +// Arguments.of("arn:aws:secretsmanager:us-west-1:123456789012:secret:boo", Region.US_WEST_1), +// Arguments.of( +// "arn:aws:secretsmanager:us-east-2:123456789012:secret:rds!cluster-bar-foo", +// Region.US_EAST_2) +// ); +// } +// +// private static Stream missingArguments() { +// final Properties missingId = new Properties(); +// REGION_PROPERTY.set(missingId, TEST_REGION); +// +// final Properties missingRegion = new Properties(); +// SECRET_ID_PROPERTY.set(missingRegion, TEST_SECRET_ID); +// +// return Stream.of( +// Arguments.of(missingId), +// Arguments.of(missingRegion) +// ); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java index d8fac47be..d44e081b0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java @@ -1,138 +1,138 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Properties; -import java.util.stream.Stream; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.ConnectionProviderManager; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginManagerService; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.util.telemetry.GaugeCallable; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryGauge; - -class DefaultConnectionPluginTest { - - private DefaultConnectionPlugin plugin; - - @Mock PluginService pluginService; - @Mock ConnectionProvider connectionProvider; - @Mock PluginManagerService pluginManagerService; - @Mock JdbcCallable mockSqlFunction; - @Mock JdbcCallable mockConnectFunction; - @Mock Connection conn; - @Mock Connection oldConn; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryGauge mockTelemetryGauge; - @Mock ConnectionProviderManager mockConnectionProviderManager; - @Mock HostSpec mockHostSpec; - - - private AutoCloseable closeable; - - @BeforeEach - void setUp() { - closeable = MockitoAnnotations.openMocks(this); - - when(pluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - // noinspection unchecked - when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any())) - .thenReturn(connectionProvider); - - plugin = new DefaultConnectionPlugin( - pluginService, connectionProvider, null, pluginManagerService, mockConnectionProviderManager); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @ParameterizedTest - @MethodSource("multiStatementQueries") - void testParseMultiStatementQueries(final String sql, final List expected) { - final List actual = plugin.parseMultiStatementQueries(sql); - assertEquals(expected, actual); - } - - @Test - void testExecute_closeCurrentConnection() throws SQLException { - when(this.pluginService.getCurrentConnection()).thenReturn(conn); - plugin.execute(Void.class, SQLException.class, conn, "Connection.close", mockSqlFunction, new Object[]{}); - verify(pluginManagerService, times(1)).setInTransaction(false); - } - - @Test - void testExecute_closeOldConnection() throws SQLException { - when(this.pluginService.getCurrentConnection()).thenReturn(conn); - plugin.execute(Void.class, SQLException.class, oldConn, "Connection.close", mockSqlFunction, new Object[]{}); - verify(pluginManagerService, never()).setInTransaction(anyBoolean()); - } - - @Test - void testConnect() throws SQLException { - plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction); - verify(connectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any(), any()); - verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any()); - } - - private static Stream multiStatementQueries() { - return Stream.of( - Arguments.of("", new ArrayList()), - Arguments.of(null, new ArrayList()), - Arguments.of(" ", new ArrayList()), - Arguments.of("some \t \r \n query;", Collections.singletonList("some query")), - Arguments.of("some\t\t\r\n query;query2", Arrays.asList("some query", "query2")) - ); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyBoolean; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.atLeastOnce; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.ArrayList; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.List; +// import java.util.Properties; +// import java.util.stream.Stream; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.Arguments; +// import org.junit.jupiter.params.provider.MethodSource; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.ConnectionProvider; +// import software.amazon.jdbc.ConnectionProviderManager; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginManagerService; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.util.telemetry.GaugeCallable; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryGauge; +// +// class DefaultConnectionPluginTest { +// +// private DefaultConnectionPlugin plugin; +// +// @Mock PluginService pluginService; +// @Mock ConnectionProvider connectionProvider; +// @Mock PluginManagerService pluginManagerService; +// @Mock JdbcCallable mockSqlFunction; +// @Mock JdbcCallable mockConnectFunction; +// @Mock Connection conn; +// @Mock Connection oldConn; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryGauge mockTelemetryGauge; +// @Mock ConnectionProviderManager mockConnectionProviderManager; +// @Mock HostSpec mockHostSpec; +// +// +// private AutoCloseable closeable; +// +// @BeforeEach +// void setUp() { +// closeable = MockitoAnnotations.openMocks(this); +// +// when(pluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// // noinspection unchecked +// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); +// when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any())) +// .thenReturn(connectionProvider); +// +// plugin = new DefaultConnectionPlugin( +// pluginService, connectionProvider, null, pluginManagerService, mockConnectionProviderManager); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @ParameterizedTest +// @MethodSource("multiStatementQueries") +// void testParseMultiStatementQueries(final String sql, final List expected) { +// final List actual = plugin.parseMultiStatementQueries(sql); +// assertEquals(expected, actual); +// } +// +// @Test +// void testExecute_closeCurrentConnection() throws SQLException { +// when(this.pluginService.getCurrentConnection()).thenReturn(conn); +// plugin.execute(Void.class, SQLException.class, conn, "Connection.close", mockSqlFunction, new Object[]{}); +// verify(pluginManagerService, times(1)).setInTransaction(false); +// } +// +// @Test +// void testExecute_closeOldConnection() throws SQLException { +// when(this.pluginService.getCurrentConnection()).thenReturn(conn); +// plugin.execute(Void.class, SQLException.class, oldConn, "Connection.close", mockSqlFunction, new Object[]{}); +// verify(pluginManagerService, never()).setInTransaction(anyBoolean()); +// } +// +// @Test +// void testConnect() throws SQLException { +// plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction); +// verify(connectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any(), any()); +// verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any()); +// } +// +// private static Stream multiStatementQueries() { +// return Stream.of( +// Arguments.of("", new ArrayList()), +// Arguments.of(null, new ArrayList()), +// Arguments.of(" ", new ArrayList()), +// Arguments.of("some \t \r \n query;", Collections.singletonList("some query")), +// Arguments.of("some\t\t\r\n query;query2", Arrays.asList("some query", "query2")) +// ); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java index 0d41c5f72..e465df7f9 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java @@ -1,159 +1,159 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.customendpoint; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; - -import java.sql.Connection; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.HashSet; -import java.util.Properties; -import java.util.function.BiFunction; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.rds.RdsClient; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.monitoring.MonitorService; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; - -public class CustomEndpointPluginTest { - private final String writerClusterUrl = "writer.cluster-XYZ.us-east-1.rds.amazonaws.com"; - private final String customEndpointUrl = "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; - - private AutoCloseable closeable; - private final Properties props = new Properties(); - private final HostAvailabilityStrategy availabilityStrategy = new SimpleHostAvailabilityStrategy(); - private final HostSpecBuilder hostSpecBuilder = new HostSpecBuilder(availabilityStrategy); - private final HostSpec writerClusterHost = hostSpecBuilder.host(writerClusterUrl).build(); - private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); - - @Mock private FullServicesContainer mockServicesContainer; - @Mock private PluginService mockPluginService; - @Mock private MonitorService mockMonitorService; - @Mock private BiFunction mockRdsClientFunc; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock private TelemetryCounter mockTelemetryCounter; - @Mock private JdbcCallable mockConnectFunc; - @Mock private JdbcCallable mockJdbcMethodFunc; - @Mock private Connection mockConnection; - @Mock private CustomEndpointMonitor mockMonitor; - @Mock TargetDriverDialect mockTargetDriverDialect; - - - @BeforeEach - public void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - - when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); - when(mockServicesContainer.getMonitorService()).thenReturn(mockMonitorService); - when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.createCounter(any(String.class))).thenReturn(mockTelemetryCounter); - when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); - when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); - when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - props.clear(); - } - - private CustomEndpointPlugin getSpyPlugin() throws SQLException { - CustomEndpointPlugin plugin = new CustomEndpointPlugin(mockServicesContainer, props, mockRdsClientFunc); - CustomEndpointPlugin spyPlugin = spy(plugin); - doReturn(mockMonitor).when(spyPlugin).createMonitorIfAbsent(any(Properties.class)); - return spyPlugin; - } - - @Test - public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { - CustomEndpointPlugin spyPlugin = getSpyPlugin(); - - spyPlugin.connect("", writerClusterHost, props, true, mockConnectFunc); - - verify(mockConnectFunc, times(1)).call(); - verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); - } - - @Test - public void testConnect_monitorCreated() throws SQLException { - CustomEndpointPlugin spyPlugin = getSpyPlugin(); - - spyPlugin.connect("", host, props, true, mockConnectFunc); - - verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); - verify(mockConnectFunc, times(1)).call(); - } - - @Test - public void testConnect_timeoutWaitingForInfo() throws SQLException { - WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set(props, "1"); - CustomEndpointPlugin spyPlugin = getSpyPlugin(); - when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); - - assertThrows(SQLException.class, () -> spyPlugin.connect("", host, props, true, mockConnectFunc)); - - verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); - verify(mockConnectFunc, never()).call(); - } - - @Test - public void testExecute_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { - CustomEndpointPlugin spyPlugin = getSpyPlugin(); - - spyPlugin.execute( - Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); - - verify(mockJdbcMethodFunc, times(1)).call(); - verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); - } - - @Test - public void testExecute_monitorCreated() throws SQLException { - CustomEndpointPlugin spyPlugin = getSpyPlugin(); - spyPlugin.customEndpointHostSpec = host; - - spyPlugin.execute( - Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); - - verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); - verify(mockJdbcMethodFunc, times(1)).call(); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.customendpoint; +// +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.HashSet; +// import java.util.Properties; +// import java.util.function.BiFunction; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.awssdk.regions.Region; +// import software.amazon.awssdk.services.rds.RdsClient; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.monitoring.MonitorService; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// +// public class CustomEndpointPluginTest { +// private final String writerClusterUrl = "writer.cluster-XYZ.us-east-1.rds.amazonaws.com"; +// private final String customEndpointUrl = "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; +// +// private AutoCloseable closeable; +// private final Properties props = new Properties(); +// private final HostAvailabilityStrategy availabilityStrategy = new SimpleHostAvailabilityStrategy(); +// private final HostSpecBuilder hostSpecBuilder = new HostSpecBuilder(availabilityStrategy); +// private final HostSpec writerClusterHost = hostSpecBuilder.host(writerClusterUrl).build(); +// private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); +// +// @Mock private FullServicesContainer mockServicesContainer; +// @Mock private PluginService mockPluginService; +// @Mock private MonitorService mockMonitorService; +// @Mock private BiFunction mockRdsClientFunc; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock private TelemetryCounter mockTelemetryCounter; +// @Mock private JdbcCallable mockConnectFunc; +// @Mock private JdbcCallable mockJdbcMethodFunc; +// @Mock private Connection mockConnection; +// @Mock private CustomEndpointMonitor mockMonitor; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// +// +// @BeforeEach +// public void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// +// when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); +// when(mockServicesContainer.getMonitorService()).thenReturn(mockMonitorService); +// when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.createCounter(any(String.class))).thenReturn(mockTelemetryCounter); +// when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); +// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); +// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// props.clear(); +// } +// +// private CustomEndpointPlugin getSpyPlugin() throws SQLException { +// CustomEndpointPlugin plugin = new CustomEndpointPlugin(mockServicesContainer, props, mockRdsClientFunc); +// CustomEndpointPlugin spyPlugin = spy(plugin); +// doReturn(mockMonitor).when(spyPlugin).createMonitorIfAbsent(any(Properties.class)); +// return spyPlugin; +// } +// +// @Test +// public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { +// CustomEndpointPlugin spyPlugin = getSpyPlugin(); +// +// spyPlugin.connect("", writerClusterHost, props, true, mockConnectFunc); +// +// verify(mockConnectFunc, times(1)).call(); +// verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); +// } +// +// @Test +// public void testConnect_monitorCreated() throws SQLException { +// CustomEndpointPlugin spyPlugin = getSpyPlugin(); +// +// spyPlugin.connect("", host, props, true, mockConnectFunc); +// +// verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); +// verify(mockConnectFunc, times(1)).call(); +// } +// +// @Test +// public void testConnect_timeoutWaitingForInfo() throws SQLException { +// WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set(props, "1"); +// CustomEndpointPlugin spyPlugin = getSpyPlugin(); +// when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); +// +// assertThrows(SQLException.class, () -> spyPlugin.connect("", host, props, true, mockConnectFunc)); +// +// verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); +// verify(mockConnectFunc, never()).call(); +// } +// +// @Test +// public void testExecute_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { +// CustomEndpointPlugin spyPlugin = getSpyPlugin(); +// +// spyPlugin.execute( +// Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); +// +// verify(mockJdbcMethodFunc, times(1)).call(); +// verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); +// } +// +// @Test +// public void testExecute_monitorCreated() throws SQLException { +// CustomEndpointPlugin spyPlugin = getSpyPlugin(); +// spyPlugin.customEndpointHostSpec = host; +// +// spyPlugin.execute( +// Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); +// +// verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); +// verify(mockJdbcMethodFunc, times(1)).call(); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java index f4cc60fec..377b511eb 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java @@ -1,356 +1,356 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.dev; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.ConnectionPluginManager; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.DialectCodes; -import software.amazon.jdbc.dialect.DialectManager; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.FullServicesContainerImpl; -import software.amazon.jdbc.util.monitoring.MonitorService; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.wrapper.ConnectionWrapper; - -@SuppressWarnings({"resource"}) -public class DeveloperConnectionPluginTest { - private FullServicesContainer servicesContainer; - @Mock StorageService mockStorageService; - @Mock MonitorService mockMonitorService; - @Mock ConnectionProvider mockConnectionProvider; - @Mock Connection mockConnection; - @Mock ConnectionPluginManager mockConnectionPluginManager; - @Mock ExceptionSimulatorConnectCallback mockConnectCallback; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock TelemetryContext mockTelemetryContext; - @Mock TargetDriverDialect mockTargetDriverDialect; - - private AutoCloseable closeable; - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @BeforeEach - void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - servicesContainer = new FullServicesContainerImpl( - mockStorageService, mockMonitorService, mockConnectionProvider, mockTelemetryFactory); - - when(mockConnectionProvider.connect(any(), any(), any(), any(), any())).thenReturn(mockConnection); - when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())).thenReturn(null); - - when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); - when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); - } - - @Test - @SuppressWarnings("try") - public void test_RaiseException() throws SQLException { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - try (ConnectionWrapper wrapper = new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)) { - - ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); - assertNotNull(simulator); - - assertDoesNotThrow(() -> wrapper.createStatement()); - - final RuntimeException runtimeException = new RuntimeException("test"); - simulator.raiseExceptionOnNextCall(runtimeException); - Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); - assertSame(runtimeException, thrownException); - - assertDoesNotThrow(() -> wrapper.createStatement()); - } - } - - @Test - public void test_RaiseExceptionForMethodName() throws SQLException { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - try (ConnectionWrapper wrapper = new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)) { - - ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); - assertNotNull(simulator); - - assertDoesNotThrow(() -> wrapper.createStatement()); - - final RuntimeException runtimeException = new RuntimeException("test"); - simulator.raiseExceptionOnNextCall("Connection.createStatement", runtimeException); - Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); - assertSame(runtimeException, thrownException); - - assertDoesNotThrow(() -> wrapper.createStatement()); - } - } - - @Test - public void test_RaiseExceptionForAnyMethodName() throws SQLException { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - try (ConnectionWrapper wrapper = new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)) { - - ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); - assertNotNull(simulator); - - assertDoesNotThrow(() -> wrapper.createStatement()); - - final RuntimeException runtimeException = new RuntimeException("test"); - simulator.raiseExceptionOnNextCall("*", runtimeException); - Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); - assertSame(runtimeException, thrownException); - - assertDoesNotThrow(() -> wrapper.createStatement()); - } - } - - @Test - public void test_RaiseExceptionForWrongMethodName() throws SQLException { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - try (ConnectionWrapper wrapper = new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)) { - - ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); - assertNotNull(simulator); - - assertDoesNotThrow(() -> wrapper.createStatement()); - - final RuntimeException runtimeException = new RuntimeException("test"); - simulator.raiseExceptionOnNextCall("Connection.isClosed", runtimeException); - assertDoesNotThrow(() -> wrapper.createStatement()); - - Throwable thrownException = assertThrows(RuntimeException.class, wrapper::isClosed); - assertSame(runtimeException, thrownException); - - assertDoesNotThrow(() -> wrapper.createStatement()); - } - } - - @Test - public void test_RaiseExpectedExceptionClass() throws SQLException { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - try (ConnectionWrapper wrapper = new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)) { - - ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); - assertNotNull(simulator); - - assertDoesNotThrow(() -> wrapper.createStatement()); - - final SQLException sqlException = new SQLException("test"); - simulator.raiseExceptionOnNextCall(sqlException); - Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); - assertSame(sqlException, thrownException); - - assertDoesNotThrow(() -> wrapper.createStatement()); - } - } - - @Test - public void test_RaiseUnexpectedExceptionClass() throws SQLException { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - try (ConnectionWrapper wrapper = new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)) { - - ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); - assertNotNull(simulator); - - assertDoesNotThrow(() -> wrapper.createStatement()); - - final Exception exception = new Exception("test"); - simulator.raiseExceptionOnNextCall(exception); - Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); - assertNotNull(thrownException); - assertNotSame(exception, thrownException); - assertInstanceOf(SQLException.class, thrownException); - assertNotNull(thrownException.getCause()); - assertSame(thrownException.getCause(), exception); - - assertDoesNotThrow(() -> wrapper.createStatement()); - } - } - - @Test - public void test_RaiseExceptionOnConnect() { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - - final SQLException exception = new SQLException("test"); - ExceptionSimulatorManager.raiseExceptionOnNextConnect(exception); - - Throwable thrownException = assertThrows( - SQLException.class, - () -> new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)); - assertSame(exception, thrownException); - - assertDoesNotThrow( - () -> new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)); - } - - @Test - public void test_NoExceptionOnConnectWithCallback() { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - - ExceptionSimulatorManager.setCallback(mockConnectCallback); - - assertDoesNotThrow( - () -> new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)); - } - - @Test - public void test_RaiseExceptionOnConnectWithCallback() { - - final Properties props = new Properties(); - props.put(PropertyDefinition.PLUGINS.name, "dev"); - props.put(DialectManager.DIALECT.name, DialectCodes.PG); - - final SQLException exception = new SQLException("test"); - when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())) - .thenReturn(exception) - .thenReturn(null); - ExceptionSimulatorManager.setCallback(mockConnectCallback); - - Throwable thrownException = assertThrows( - SQLException.class, - () -> new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)); - assertSame(exception, thrownException); - - assertDoesNotThrow( - () -> new ConnectionWrapper( - servicesContainer, - props, - "any-protocol://any-host/", - mockConnectionProvider, - null, - mockTargetDriverDialect, - null)); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.dev; +// +// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +// import static org.junit.jupiter.api.Assertions.assertInstanceOf; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertNotSame; +// import static org.junit.jupiter.api.Assertions.assertSame; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyBoolean; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.ConnectionPluginManager; +// import software.amazon.jdbc.ConnectionProvider; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.dialect.DialectCodes; +// import software.amazon.jdbc.dialect.DialectManager; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.FullServicesContainerImpl; +// import software.amazon.jdbc.util.monitoring.MonitorService; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.wrapper.ConnectionWrapper; +// +// @SuppressWarnings({"resource"}) +// public class DeveloperConnectionPluginTest { +// private FullServicesContainer servicesContainer; +// @Mock StorageService mockStorageService; +// @Mock MonitorService mockMonitorService; +// @Mock ConnectionProvider mockConnectionProvider; +// @Mock Connection mockConnection; +// @Mock ConnectionPluginManager mockConnectionPluginManager; +// @Mock ExceptionSimulatorConnectCallback mockConnectCallback; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// +// private AutoCloseable closeable; +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @BeforeEach +// void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// servicesContainer = new FullServicesContainerImpl( +// mockStorageService, mockMonitorService, mockConnectionProvider, mockTelemetryFactory); +// +// when(mockConnectionProvider.connect(any(), any())).thenReturn(mockConnection); +// when(mockConnectCallback.getExceptionToRaise(any(), any(), anyBoolean())).thenReturn(null); +// +// when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); +// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); +// } +// +// @Test +// @SuppressWarnings("try") +// public void test_RaiseException() throws SQLException { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// try (ConnectionWrapper wrapper = new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)) { +// +// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); +// assertNotNull(simulator); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// final RuntimeException runtimeException = new RuntimeException("test"); +// simulator.raiseExceptionOnNextCall(runtimeException); +// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); +// assertSame(runtimeException, thrownException); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// } +// } +// +// @Test +// public void test_RaiseExceptionForMethodName() throws SQLException { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// try (ConnectionWrapper wrapper = new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)) { +// +// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); +// assertNotNull(simulator); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// final RuntimeException runtimeException = new RuntimeException("test"); +// simulator.raiseExceptionOnNextCall("Connection.createStatement", runtimeException); +// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); +// assertSame(runtimeException, thrownException); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// } +// } +// +// @Test +// public void test_RaiseExceptionForAnyMethodName() throws SQLException { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// try (ConnectionWrapper wrapper = new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)) { +// +// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); +// assertNotNull(simulator); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// final RuntimeException runtimeException = new RuntimeException("test"); +// simulator.raiseExceptionOnNextCall("*", runtimeException); +// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); +// assertSame(runtimeException, thrownException); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// } +// } +// +// @Test +// public void test_RaiseExceptionForWrongMethodName() throws SQLException { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// try (ConnectionWrapper wrapper = new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)) { +// +// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); +// assertNotNull(simulator); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// final RuntimeException runtimeException = new RuntimeException("test"); +// simulator.raiseExceptionOnNextCall("Connection.isClosed", runtimeException); +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::isClosed); +// assertSame(runtimeException, thrownException); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// } +// } +// +// @Test +// public void test_RaiseExpectedExceptionClass() throws SQLException { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// try (ConnectionWrapper wrapper = new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)) { +// +// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); +// assertNotNull(simulator); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// final SQLException sqlException = new SQLException("test"); +// simulator.raiseExceptionOnNextCall(sqlException); +// Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); +// assertSame(sqlException, thrownException); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// } +// } +// +// @Test +// public void test_RaiseUnexpectedExceptionClass() throws SQLException { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// try (ConnectionWrapper wrapper = new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)) { +// +// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); +// assertNotNull(simulator); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// +// final Exception exception = new Exception("test"); +// simulator.raiseExceptionOnNextCall(exception); +// Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); +// assertNotNull(thrownException); +// assertNotSame(exception, thrownException); +// assertInstanceOf(SQLException.class, thrownException); +// assertNotNull(thrownException.getCause()); +// assertSame(thrownException.getCause(), exception); +// +// assertDoesNotThrow(() -> wrapper.createStatement()); +// } +// } +// +// @Test +// public void test_RaiseExceptionOnConnect() { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// +// final SQLException exception = new SQLException("test"); +// ExceptionSimulatorManager.raiseExceptionOnNextConnect(exception); +// +// Throwable thrownException = assertThrows( +// SQLException.class, +// () -> new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)); +// assertSame(exception, thrownException); +// +// assertDoesNotThrow( +// () -> new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)); +// } +// +// @Test +// public void test_NoExceptionOnConnectWithCallback() { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// +// ExceptionSimulatorManager.setCallback(mockConnectCallback); +// +// assertDoesNotThrow( +// () -> new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)); +// } +// +// @Test +// public void test_RaiseExceptionOnConnectWithCallback() { +// +// final Properties props = new Properties(); +// props.put(PropertyDefinition.PLUGINS.name, "dev"); +// props.put(DialectManager.DIALECT.name, DialectCodes.PG); +// +// final SQLException exception = new SQLException("test"); +// when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())) +// .thenReturn(exception) +// .thenReturn(null); +// ExceptionSimulatorManager.setCallback(mockConnectCallback); +// +// Throwable thrownException = assertThrows( +// SQLException.class, +// () -> new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)); +// assertSame(exception, thrownException); +// +// assertDoesNotThrow( +// () -> new ConnectionWrapper( +// servicesContainer, +// props, +// "any-protocol://any-host/", +// mockConnectionProvider, +// null, +// mockTargetDriverDialect, +// null)); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java index 63320e95d..8b8d8dbb5 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java @@ -1,344 +1,344 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.efm; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anySet; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.atMostOnce; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Supplier; -import java.util.stream.Stream; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.NodeChangeOptions; -import software.amazon.jdbc.OldConnectionSuggestedAction; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.HostAvailability; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.RdsUrlType; -import software.amazon.jdbc.util.RdsUtils; - -class HostMonitoringConnectionPluginTest { - - static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; - static final String MONITOR_METHOD_NAME = JdbcMethod.STATEMENT_EXECUTEQUERY.methodName; - static final String NO_MONITOR_METHOD_NAME = JdbcMethod.CONNECTION_ABORT.methodName; - static final int FAILURE_DETECTION_TIME = 10; - static final int FAILURE_DETECTION_INTERVAL = 100; - static final int FAILURE_DETECTION_COUNT = 5; - private static final Object[] EMPTY_ARGS = {}; - @Mock PluginService pluginService; - @Mock Dialect mockDialect; - @Mock Connection connection; - @Mock Statement statement; - @Mock ResultSet resultSet; - @Captor ArgumentCaptor stringArgumentCaptor; - Properties properties = new Properties(); - @Mock HostSpec hostSpec; - @Mock HostSpec hostSpec2; - @Mock Supplier supplier; - @Mock RdsUtils rdsUtils; - @Mock HostMonitorConnectionContext context; - @Mock ReentrantLock mockReentrantLock; - @Mock HostMonitorService monitorService; - @Mock JdbcCallable sqlFunction; - @Mock TargetDriverDialect targetDriverDialect; - - private HostMonitoringConnectionPlugin plugin; - private AutoCloseable closeable; - - /** - * Generate different sets of method arguments where one argument is null to ensure {@link - * software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin#HostMonitoringConnectionPlugin(PluginService, - * Properties)} can handle null arguments correctly. - * - * @return different sets of arguments. - */ - private static Stream generateNullArguments() { - final PluginService pluginService = mock(PluginService.class); - final Properties properties = new Properties(); - - return Stream.of( - Arguments.of(null, null), - Arguments.of(pluginService, null), - Arguments.of(null, properties)); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @BeforeEach - void init() throws Exception { - closeable = MockitoAnnotations.openMocks(this); - initDefaultMockReturns(); - properties.clear(); - } - - void initDefaultMockReturns() throws Exception { - when(supplier.get()).thenReturn(monitorService); - when(monitorService.startMonitoring( - any(Connection.class), - anySet(), - any(HostSpec.class), - any(Properties.class), - anyInt(), - anyInt(), - anyInt())) - .thenReturn(context); - when(context.getLock()).thenReturn(mockReentrantLock); - - when(pluginService.getCurrentConnection()).thenReturn(connection); - when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); - when(pluginService.getDialect()).thenReturn(mockDialect); - when(pluginService.getTargetDriverDialect()).thenReturn(targetDriverDialect); - when(targetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn( - new HashSet<>(Collections.singletonList(MONITOR_METHOD_NAME))); - when(mockDialect.getHostAliasQuery()).thenReturn("any"); - when(hostSpec.getHost()).thenReturn("host"); - when(hostSpec.getHost()).thenReturn("port"); - when(hostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); - when(hostSpec2.getHost()).thenReturn("host"); - when(hostSpec2.getHost()).thenReturn("port"); - when(hostSpec2.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); - when(connection.createStatement()).thenReturn(statement); - when(statement.executeQuery(any())).thenReturn(resultSet); - when(rdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); - - properties.put("failureDetectionEnabled", Boolean.TRUE.toString()); - properties.put("failureDetectionTime", String.valueOf(FAILURE_DETECTION_TIME)); - properties.put("failureDetectionInterval", String.valueOf(FAILURE_DETECTION_INTERVAL)); - properties.put("failureDetectionCount", String.valueOf(FAILURE_DETECTION_COUNT)); - } - - private void initializePlugin() { - plugin = new HostMonitoringConnectionPlugin(pluginService, properties, supplier, rdsUtils); - } - - @Test - void test_executeWithMonitoringDisabled() throws Exception { - properties.put("failureDetectionEnabled", Boolean.FALSE.toString()); - - initializePlugin(); - - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS); - - verify(supplier, never()).get(); - verify(monitorService, never()) - .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); - verify(monitorService, never()).stopMonitoring(context); - verify(sqlFunction, times(1)).call(); - } - - @Test - void test_executeWithNoNeedToMonitor() throws Exception { - - initializePlugin(); - - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - NO_MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS); - - verify(supplier, atMostOnce()).get(); - verify(monitorService, never()) - .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); - verify(monitorService, never()).stopMonitoring(context); - verify(sqlFunction, times(1)).call(); - } - - @Test - void test_executeMonitoringEnabled() throws Exception { - - initializePlugin(); - - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS); - - verify(supplier, times(1)).get(); - verify(monitorService, times(1)) - .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); - verify(monitorService, times(1)).stopMonitoring(context); - verify(sqlFunction, times(1)).call(); - } - - /** - * Tests exception being thrown in the finally block when checking connection status in the execute method. - */ - @Test - void test_executeCleanUp_whenCheckingConnectionStatus_throwsException() throws SQLException { - initializePlugin(); - - final SQLException expectedException = new SQLException("exception thrown during isClosed"); - when(context.isNodeUnhealthy()).thenReturn(true); - doThrow(expectedException).when(connection).isClosed(); - final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS)); - - assertEquals(expectedException, actualException); - } - - /** - * Tests exception being thrown in the finally block - * when an open connection object is detected for an unavailable node in the execute method. - */ - @Test - void test_executeCleanUp_whenAbortConnection_throwsException() throws SQLException { - initializePlugin(); - - final String errorMessage = Messages.get( - "HostMonitoringConnectionPlugin.unavailableNode", - new Object[] {"alias"}); - - when(hostSpec.asAlias()).thenReturn("alias"); - when(connection.isClosed()).thenReturn(false); - when(context.isNodeUnhealthy()).thenReturn(true); - final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS)); - - assertEquals(errorMessage, actualException.getMessage()); - verify(pluginService).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); - verify(connection).close(); - } - - @Test - void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException { - initializePlugin(); - - doThrow(new SQLException()).when(connection).createStatement(); - - // Ensure SQLException raised in `generateHostAliases` are ignored. - final Connection conn = plugin.connect("protocol", hostSpec, properties, true, () -> connection); - assertNotNull(conn); - } - - @ParameterizedTest - @MethodSource("nodeChangeOptions") - void test_notifyConnectionChanged_nodeWentDown(final NodeChangeOptions option) throws SQLException { - initializePlugin(); - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS); - - final Set aliases1 = new HashSet<>(Arrays.asList("alias1", "alias2")); - final Set aliases2 = new HashSet<>(Arrays.asList("alias3", "alias4")); - when(hostSpec.asAliases()).thenReturn(aliases1); - when(hostSpec2.asAliases()).thenReturn(aliases2); - when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); - - assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); - // NodeKeys should contain {"alias1", "alias2"} - verify(monitorService).stopMonitoringForAllConnections(aliases1); - - when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec2); - assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); - // NotifyConnectionChanged should reset the monitoringHostSpec. - // NodeKeys should contain {"alias3", "alias4"} - verify(monitorService).stopMonitoringForAllConnections(aliases2); - } - - @Test - void test_releaseResources() throws SQLException { - initializePlugin(); - - // Test releaseResources when the monitor service has not been initialized. - plugin.releaseResources(); - verify(monitorService, never()).releaseResources(); - - // Test releaseResources when the monitor service has been initialized. - plugin.execute( - ResultSet.class, - SQLException.class, - MONITOR_METHOD_INVOKE_ON, - MONITOR_METHOD_NAME, - sqlFunction, - EMPTY_ARGS); - plugin.releaseResources(); - verify(monitorService).releaseResources(); - } - - static Stream nodeChangeOptions() { - return Stream.of( - Arguments.of(NodeChangeOptions.WENT_DOWN), - Arguments.of(NodeChangeOptions.NODE_DELETED) - ); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.efm; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.ArgumentMatchers.anySet; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.any; +// import static org.mockito.Mockito.atMostOnce; +// import static org.mockito.Mockito.doThrow; +// import static org.mockito.Mockito.mock; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.HashSet; +// import java.util.Properties; +// import java.util.Set; +// import java.util.concurrent.locks.ReentrantLock; +// import java.util.function.Supplier; +// import java.util.stream.Stream; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.params.ParameterizedTest; +// import org.junit.jupiter.params.provider.Arguments; +// import org.junit.jupiter.params.provider.MethodSource; +// import org.mockito.ArgumentCaptor; +// import org.mockito.Captor; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.JdbcMethod; +// import software.amazon.jdbc.NodeChangeOptions; +// import software.amazon.jdbc.OldConnectionSuggestedAction; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.HostAvailability; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.Messages; +// import software.amazon.jdbc.util.RdsUrlType; +// import software.amazon.jdbc.util.RdsUtils; +// +// class HostMonitoringConnectionPluginTest { +// +// static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; +// static final String MONITOR_METHOD_NAME = JdbcMethod.STATEMENT_EXECUTEQUERY.methodName; +// static final String NO_MONITOR_METHOD_NAME = JdbcMethod.CONNECTION_ABORT.methodName; +// static final int FAILURE_DETECTION_TIME = 10; +// static final int FAILURE_DETECTION_INTERVAL = 100; +// static final int FAILURE_DETECTION_COUNT = 5; +// private static final Object[] EMPTY_ARGS = {}; +// @Mock PluginService pluginService; +// @Mock Dialect mockDialect; +// @Mock Connection connection; +// @Mock Statement statement; +// @Mock ResultSet resultSet; +// @Captor ArgumentCaptor stringArgumentCaptor; +// Properties properties = new Properties(); +// @Mock HostSpec hostSpec; +// @Mock HostSpec hostSpec2; +// @Mock Supplier supplier; +// @Mock RdsUtils rdsUtils; +// @Mock HostMonitorConnectionContext context; +// @Mock ReentrantLock mockReentrantLock; +// @Mock HostMonitorService monitorService; +// @Mock JdbcCallable sqlFunction; +// @Mock TargetDriverDialect targetDriverDialect; +// +// private HostMonitoringConnectionPlugin plugin; +// private AutoCloseable closeable; +// +// /** +// * Generate different sets of method arguments where one argument is null to ensure {@link +// * software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin#HostMonitoringConnectionPlugin(PluginService, +// * Properties)} can handle null arguments correctly. +// * +// * @return different sets of arguments. +// */ +// private static Stream generateNullArguments() { +// final PluginService pluginService = mock(PluginService.class); +// final Properties properties = new Properties(); +// +// return Stream.of( +// Arguments.of(null, null), +// Arguments.of(pluginService, null), +// Arguments.of(null, properties)); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @BeforeEach +// void init() throws Exception { +// closeable = MockitoAnnotations.openMocks(this); +// initDefaultMockReturns(); +// properties.clear(); +// } +// +// void initDefaultMockReturns() throws Exception { +// when(supplier.get()).thenReturn(monitorService); +// when(monitorService.startMonitoring( +// any(Connection.class), +// anySet(), +// any(HostSpec.class), +// any(Properties.class), +// anyInt(), +// anyInt(), +// anyInt())) +// .thenReturn(context); +// when(context.getLock()).thenReturn(mockReentrantLock); +// +// when(pluginService.getCurrentConnection()).thenReturn(connection); +// when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); +// when(pluginService.getDialect()).thenReturn(mockDialect); +// when(pluginService.getTargetDriverDialect()).thenReturn(targetDriverDialect); +// when(targetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn( +// new HashSet<>(Collections.singletonList(MONITOR_METHOD_NAME))); +// when(mockDialect.getHostAliasQuery()).thenReturn("any"); +// when(hostSpec.getHost()).thenReturn("host"); +// when(hostSpec.getHost()).thenReturn("port"); +// when(hostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); +// when(hostSpec2.getHost()).thenReturn("host"); +// when(hostSpec2.getHost()).thenReturn("port"); +// when(hostSpec2.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); +// when(connection.createStatement()).thenReturn(statement); +// when(statement.executeQuery(any())).thenReturn(resultSet); +// when(rdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); +// +// properties.put("failureDetectionEnabled", Boolean.TRUE.toString()); +// properties.put("failureDetectionTime", String.valueOf(FAILURE_DETECTION_TIME)); +// properties.put("failureDetectionInterval", String.valueOf(FAILURE_DETECTION_INTERVAL)); +// properties.put("failureDetectionCount", String.valueOf(FAILURE_DETECTION_COUNT)); +// } +// +// private void initializePlugin() { +// plugin = new HostMonitoringConnectionPlugin(pluginService, properties, supplier, rdsUtils); +// } +// +// @Test +// void test_executeWithMonitoringDisabled() throws Exception { +// properties.put("failureDetectionEnabled", Boolean.FALSE.toString()); +// +// initializePlugin(); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS); +// +// verify(supplier, never()).get(); +// verify(monitorService, never()) +// .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); +// verify(monitorService, never()).stopMonitoring(context); +// verify(sqlFunction, times(1)).call(); +// } +// +// @Test +// void test_executeWithNoNeedToMonitor() throws Exception { +// +// initializePlugin(); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// NO_MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS); +// +// verify(supplier, atMostOnce()).get(); +// verify(monitorService, never()) +// .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); +// verify(monitorService, never()).stopMonitoring(context); +// verify(sqlFunction, times(1)).call(); +// } +// +// @Test +// void test_executeMonitoringEnabled() throws Exception { +// +// initializePlugin(); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS); +// +// verify(supplier, times(1)).get(); +// verify(monitorService, times(1)) +// .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); +// verify(monitorService, times(1)).stopMonitoring(context); +// verify(sqlFunction, times(1)).call(); +// } +// +// /** +// * Tests exception being thrown in the finally block when checking connection status in the execute method. +// */ +// @Test +// void test_executeCleanUp_whenCheckingConnectionStatus_throwsException() throws SQLException { +// initializePlugin(); +// +// final SQLException expectedException = new SQLException("exception thrown during isClosed"); +// when(context.isNodeUnhealthy()).thenReturn(true); +// doThrow(expectedException).when(connection).isClosed(); +// final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS)); +// +// assertEquals(expectedException, actualException); +// } +// +// /** +// * Tests exception being thrown in the finally block +// * when an open connection object is detected for an unavailable node in the execute method. +// */ +// @Test +// void test_executeCleanUp_whenAbortConnection_throwsException() throws SQLException { +// initializePlugin(); +// +// final String errorMessage = Messages.get( +// "HostMonitoringConnectionPlugin.unavailableNode", +// new Object[] {"alias"}); +// +// when(hostSpec.asAlias()).thenReturn("alias"); +// when(connection.isClosed()).thenReturn(false); +// when(context.isNodeUnhealthy()).thenReturn(true); +// final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS)); +// +// assertEquals(errorMessage, actualException.getMessage()); +// verify(pluginService).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); +// verify(connection).close(); +// } +// +// @Test +// void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException { +// initializePlugin(); +// +// doThrow(new SQLException()).when(connection).createStatement(); +// +// // Ensure SQLException raised in `generateHostAliases` are ignored. +// final Connection conn = plugin.connect("protocol", hostSpec, properties, true, () -> connection); +// assertNotNull(conn); +// } +// +// @ParameterizedTest +// @MethodSource("nodeChangeOptions") +// void test_notifyConnectionChanged_nodeWentDown(final NodeChangeOptions option) throws SQLException { +// initializePlugin(); +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS); +// +// final Set aliases1 = new HashSet<>(Arrays.asList("alias1", "alias2")); +// final Set aliases2 = new HashSet<>(Arrays.asList("alias3", "alias4")); +// when(hostSpec.asAliases()).thenReturn(aliases1); +// when(hostSpec2.asAliases()).thenReturn(aliases2); +// when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); +// +// assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); +// // NodeKeys should contain {"alias1", "alias2"} +// verify(monitorService).stopMonitoringForAllConnections(aliases1); +// +// when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec2); +// assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); +// // NotifyConnectionChanged should reset the monitoringHostSpec. +// // NodeKeys should contain {"alias3", "alias4"} +// verify(monitorService).stopMonitoringForAllConnections(aliases2); +// } +// +// @Test +// void test_releaseResources() throws SQLException { +// initializePlugin(); +// +// // Test releaseResources when the monitor service has not been initialized. +// plugin.releaseResources(); +// verify(monitorService, never()).releaseResources(); +// +// // Test releaseResources when the monitor service has been initialized. +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// MONITOR_METHOD_INVOKE_ON, +// MONITOR_METHOD_NAME, +// sqlFunction, +// EMPTY_ARGS); +// plugin.releaseResources(); +// verify(monitorService).releaseResources(); +// } +// +// static Stream nodeChangeOptions() { +// return Stream.of( +// Arguments.of(NodeChangeOptions.WENT_DOWN), +// Arguments.of(NodeChangeOptions.NODE_DELETED) +// ); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java index 7b72dae6e..91028405c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -1,225 +1,225 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.federatedauth; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.time.Instant; -import java.util.Properties; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.regions.Region; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.plugin.TokenInfo; -import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; -import software.amazon.jdbc.plugin.iam.IamTokenUtility; -import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; - -class FederatedAuthPluginTest { - - private static final int DEFAULT_PORT = 1234; - private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; - private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; - private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; - private static final HostSpec HOST_SPEC = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); - private static final String DB_USER = "iamUser"; - private static final String TEST_TOKEN = "someTestToken"; - private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); - @Mock private PluginService mockPluginService; - @Mock private Dialect mockDialect; - @Mock JdbcCallable mockLambda; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock private TelemetryContext mockTelemetryContext; - @Mock private TelemetryCounter mockTelemetryCounter; - @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; - @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; - @Mock private RdsUtils mockRdsUtils; - @Mock private IamTokenUtility mockIamTokenUtils; - @Mock private CompletableFuture completableFuture; - @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; - private Properties props; - private AutoCloseable closeable; - - @BeforeEach - public void init() throws ExecutionException, InterruptedException, SQLException { - closeable = MockitoAnnotations.openMocks(this); - props = new Properties(); - props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); - props.setProperty(FederatedAuthPlugin.DB_USER.name, DB_USER); - FederatedAuthPlugin.clearCache(); - - when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); - when(mockIamTokenUtils.generateAuthenticationToken( - any(AwsCredentialsProvider.class), - any(Region.class), - anyString(), - anyInt(), - anyString())).thenReturn(TEST_TOKEN); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); - when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); - when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) - .thenReturn(mockAwsCredentialsProvider); - when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); - when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); - } - - @AfterEach - public void cleanUp() throws Exception { - closeable.close(); - } - - @Test - void testCachedToken() throws SQLException { - FederatedAuthPlugin plugin = - new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); - - String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; - FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testExpiredCachedToken() throws SQLException { - FederatedAuthPlugin spyPlugin = Mockito.spy( - new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - - String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; - String someExpiredToken = "someExpiredToken"; - TokenInfo expiredTokenInfo = new TokenInfo( - someExpiredToken, Instant.now().minusMillis(300000)); - FederatedAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); - - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, - Region.US_EAST_2, - HOST_SPEC.getHost(), - DEFAULT_PORT, - DB_USER); - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testNoCachedToken() throws SQLException { - FederatedAuthPlugin spyPlugin = Mockito.spy( - new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - verify(mockIamTokenUtils).generateAuthenticationToken( - mockAwsCredentialsProvider, - Region.US_EAST_2, - HOST_SPEC.getHost(), - DEFAULT_PORT, - DB_USER); - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testSpecifiedIamHostPortRegion() throws SQLException { - final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; - final int expectedPort = 9876; - final Region expectedRegion = Region.US_WEST_2; - - props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost); - props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); - props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString()); - - final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; - FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - - FederatedAuthPlugin plugin = - new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testIdpCredentialsFallback() throws SQLException { - String expectedUser = "expectedUser"; - String expectedPassword = "expectedPassword"; - PropertyDefinition.USER.set(props, expectedUser); - PropertyDefinition.PASSWORD.set(props, expectedPassword); - - FederatedAuthPlugin plugin = - new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; - FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props)); - assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props)); - } - - @Test - public void testUsingIamHost() throws SQLException { - IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); - FederatedAuthPlugin spyPlugin = Mockito.spy( - new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( - mockAwsCredentialsProvider, - Region.US_EAST_2, - IAM_HOST, - DEFAULT_PORT, - DB_USER); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.federatedauth; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.time.Instant; +// import java.util.Properties; +// import java.util.concurrent.CompletableFuture; +// import java.util.concurrent.ExecutionException; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +// import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +// import software.amazon.awssdk.regions.Region; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.plugin.TokenInfo; +// import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; +// import software.amazon.jdbc.plugin.iam.IamTokenUtility; +// import software.amazon.jdbc.util.RdsUtils; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// +// class FederatedAuthPluginTest { +// +// private static final int DEFAULT_PORT = 1234; +// private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; +// private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; +// private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; +// private static final HostSpec HOST_SPEC = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); +// private static final String DB_USER = "iamUser"; +// private static final String TEST_TOKEN = "someTestToken"; +// private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); +// @Mock private PluginService mockPluginService; +// @Mock private Dialect mockDialect; +// @Mock JdbcCallable mockLambda; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock private TelemetryContext mockTelemetryContext; +// @Mock private TelemetryCounter mockTelemetryCounter; +// @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; +// @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; +// @Mock private RdsUtils mockRdsUtils; +// @Mock private IamTokenUtility mockIamTokenUtils; +// @Mock private CompletableFuture completableFuture; +// @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; +// private Properties props; +// private AutoCloseable closeable; +// +// @BeforeEach +// public void init() throws ExecutionException, InterruptedException, SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// props = new Properties(); +// props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); +// props.setProperty(FederatedAuthPlugin.DB_USER.name, DB_USER); +// FederatedAuthPlugin.clearCache(); +// +// when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); +// when(mockIamTokenUtils.generateAuthenticationToken( +// any(AwsCredentialsProvider.class), +// any(Region.class), +// anyString(), +// anyInt(), +// anyString())).thenReturn(TEST_TOKEN); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); +// when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); +// when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) +// .thenReturn(mockAwsCredentialsProvider); +// when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); +// when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); +// } +// +// @AfterEach +// public void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @Test +// void testCachedToken() throws SQLException { +// FederatedAuthPlugin plugin = +// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); +// +// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; +// FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); +// +// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testExpiredCachedToken() throws SQLException { +// FederatedAuthPlugin spyPlugin = Mockito.spy( +// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); +// +// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; +// String someExpiredToken = "someExpiredToken"; +// TokenInfo expiredTokenInfo = new TokenInfo( +// someExpiredToken, Instant.now().minusMillis(300000)); +// FederatedAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); +// +// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, +// Region.US_EAST_2, +// HOST_SPEC.getHost(), +// DEFAULT_PORT, +// DB_USER); +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testNoCachedToken() throws SQLException { +// FederatedAuthPlugin spyPlugin = Mockito.spy( +// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); +// +// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// verify(mockIamTokenUtils).generateAuthenticationToken( +// mockAwsCredentialsProvider, +// Region.US_EAST_2, +// HOST_SPEC.getHost(), +// DEFAULT_PORT, +// DB_USER); +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testSpecifiedIamHostPortRegion() throws SQLException { +// final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; +// final int expectedPort = 9876; +// final Region expectedRegion = Region.US_WEST_2; +// +// props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost); +// props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); +// props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString()); +// +// final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; +// FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); +// +// FederatedAuthPlugin plugin = +// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testIdpCredentialsFallback() throws SQLException { +// String expectedUser = "expectedUser"; +// String expectedPassword = "expectedPassword"; +// PropertyDefinition.USER.set(props, expectedUser); +// PropertyDefinition.PASSWORD.set(props, expectedPassword); +// +// FederatedAuthPlugin plugin = +// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; +// FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); +// +// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props)); +// assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props)); +// } +// +// @Test +// public void testUsingIamHost() throws SQLException { +// IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); +// FederatedAuthPlugin spyPlugin = Mockito.spy( +// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); +// +// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( +// mockAwsCredentialsProvider, +// Region.US_EAST_2, +// IAM_HOST, +// DEFAULT_PORT, +// DB_USER); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java index 910e06fe1..5e22770c2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java @@ -1,220 +1,220 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.federatedauth; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.sql.Connection; -import java.sql.SQLException; -import java.time.Instant; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.plugin.TokenInfo; -import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; -import software.amazon.jdbc.plugin.iam.IamTokenUtility; -import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; - -class OktaAuthPluginTest { - - private static final int DEFAULT_PORT = 1234; - private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; - - private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; - private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; - private static final HostSpec HOST_SPEC = - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); - private static final String DB_USER = "iamUser"; - private static final String TEST_TOKEN = "someTestToken"; - private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); - @Mock private PluginService mockPluginService; - @Mock private Dialect mockDialect; - @Mock JdbcCallable mockLambda; - @Mock private TelemetryFactory mockTelemetryFactory; - @Mock private TelemetryContext mockTelemetryContext; - @Mock private TelemetryCounter mockTelemetryCounter; - @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; - @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; - @Mock private RdsUtils mockRdsUtils; - @Mock private IamTokenUtility mockIamTokenUtils; - - private Properties props; - private AutoCloseable closeable; - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - props = new Properties(); - props.setProperty(PropertyDefinition.PLUGINS.name, "okta"); - props.setProperty(OktaAuthPlugin.DB_USER.name, DB_USER); - OktaAuthPlugin.clearCache(); - - when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); - when(mockIamTokenUtils.generateAuthenticationToken( - any(AwsCredentialsProvider.class), - any(Region.class), - anyString(), - anyInt(), - anyString())).thenReturn(TEST_TOKEN); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); - when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); - when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) - .thenReturn(mockAwsCredentialsProvider); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - } - - @Test - void testCachedToken() throws SQLException { - final OktaAuthPlugin plugin = - new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; - OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testExpiredCachedToken() throws SQLException { - final OktaAuthPlugin spyPlugin = - new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; - final String someExpiredToken = "someExpiredToken"; - final TokenInfo expiredTokenInfo = new TokenInfo( - someExpiredToken, Instant.now().minusMillis(300000)); - OktaAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); - - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, - Region.US_EAST_2, - HOST_SPEC.getHost(), - DEFAULT_PORT, - DB_USER); - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testNoCachedToken() throws SQLException { - final OktaAuthPlugin spyPlugin = - new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - verify(mockIamTokenUtils).generateAuthenticationToken( - mockAwsCredentialsProvider, - Region.US_EAST_2, - HOST_SPEC.getHost(), - DEFAULT_PORT, - DB_USER); - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testSpecifiedIamHostPortRegion() throws SQLException { - final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; - final int expectedPort = 9876; - final Region expectedRegion = Region.US_WEST_2; - - props.setProperty(OktaAuthPlugin.IAM_HOST.name, expectedHost); - props.setProperty(OktaAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); - props.setProperty(OktaAuthPlugin.IAM_REGION.name, expectedRegion.toString()); - - final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; - OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - - OktaAuthPlugin plugin = - new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - @Test - void testIdpCredentialsFallback() throws SQLException { - final String expectedUser = "expectedUser"; - final String expectedPassword = "expectedPassword"; - PropertyDefinition.USER.set(props, expectedUser); - PropertyDefinition.PASSWORD.set(props, expectedPassword); - - final OktaAuthPlugin plugin = - new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - - final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; - OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - assertEquals(expectedUser, OktaAuthPlugin.IDP_USERNAME.getString(props)); - assertEquals(expectedPassword, OktaAuthPlugin.IDP_PASSWORD.getString(props)); - } - - @Test - public void testUsingIamHost() throws SQLException { - IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); - OktaAuthPlugin spyPlugin = Mockito.spy( - new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); - - assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( - mockAwsCredentialsProvider, - Region.US_EAST_2, - IAM_HOST, - DEFAULT_PORT, - DB_USER); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.federatedauth; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.time.Instant; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +// import software.amazon.awssdk.regions.Region; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.plugin.TokenInfo; +// import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; +// import software.amazon.jdbc.plugin.iam.IamTokenUtility; +// import software.amazon.jdbc.util.RdsUtils; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// +// class OktaAuthPluginTest { +// +// private static final int DEFAULT_PORT = 1234; +// private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; +// +// private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; +// private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; +// private static final HostSpec HOST_SPEC = +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); +// private static final String DB_USER = "iamUser"; +// private static final String TEST_TOKEN = "someTestToken"; +// private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); +// @Mock private PluginService mockPluginService; +// @Mock private Dialect mockDialect; +// @Mock JdbcCallable mockLambda; +// @Mock private TelemetryFactory mockTelemetryFactory; +// @Mock private TelemetryContext mockTelemetryContext; +// @Mock private TelemetryCounter mockTelemetryCounter; +// @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; +// @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; +// @Mock private RdsUtils mockRdsUtils; +// @Mock private IamTokenUtility mockIamTokenUtils; +// +// private Properties props; +// private AutoCloseable closeable; +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// props = new Properties(); +// props.setProperty(PropertyDefinition.PLUGINS.name, "okta"); +// props.setProperty(OktaAuthPlugin.DB_USER.name, DB_USER); +// OktaAuthPlugin.clearCache(); +// +// when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); +// when(mockIamTokenUtils.generateAuthenticationToken( +// any(AwsCredentialsProvider.class), +// any(Region.class), +// anyString(), +// anyInt(), +// anyString())).thenReturn(TEST_TOKEN); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); +// when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); +// when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) +// .thenReturn(mockAwsCredentialsProvider); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// } +// +// @Test +// void testCachedToken() throws SQLException { +// final OktaAuthPlugin plugin = +// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; +// OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); +// +// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testExpiredCachedToken() throws SQLException { +// final OktaAuthPlugin spyPlugin = +// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; +// final String someExpiredToken = "someExpiredToken"; +// final TokenInfo expiredTokenInfo = new TokenInfo( +// someExpiredToken, Instant.now().minusMillis(300000)); +// OktaAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); +// +// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, +// Region.US_EAST_2, +// HOST_SPEC.getHost(), +// DEFAULT_PORT, +// DB_USER); +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testNoCachedToken() throws SQLException { +// final OktaAuthPlugin spyPlugin = +// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// verify(mockIamTokenUtils).generateAuthenticationToken( +// mockAwsCredentialsProvider, +// Region.US_EAST_2, +// HOST_SPEC.getHost(), +// DEFAULT_PORT, +// DB_USER); +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testSpecifiedIamHostPortRegion() throws SQLException { +// final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; +// final int expectedPort = 9876; +// final Region expectedRegion = Region.US_WEST_2; +// +// props.setProperty(OktaAuthPlugin.IAM_HOST.name, expectedHost); +// props.setProperty(OktaAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); +// props.setProperty(OktaAuthPlugin.IAM_REGION.name, expectedRegion.toString()); +// +// final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; +// OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); +// +// OktaAuthPlugin plugin = +// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// @Test +// void testIdpCredentialsFallback() throws SQLException { +// final String expectedUser = "expectedUser"; +// final String expectedPassword = "expectedPassword"; +// PropertyDefinition.USER.set(props, expectedUser); +// PropertyDefinition.PASSWORD.set(props, expectedPassword); +// +// final OktaAuthPlugin plugin = +// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); +// +// final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; +// OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); +// +// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// assertEquals(expectedUser, OktaAuthPlugin.IDP_USERNAME.getString(props)); +// assertEquals(expectedPassword, OktaAuthPlugin.IDP_PASSWORD.getString(props)); +// } +// +// @Test +// public void testUsingIamHost() throws SQLException { +// IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); +// OktaAuthPlugin spyPlugin = Mockito.spy( +// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); +// +// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); +// +// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( +// mockAwsCredentialsProvider, +// Region.US_EAST_2, +// IAM_HOST, +// DEFAULT_PORT, +// DB_USER); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java index a872ac96c..f8153430a 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java @@ -1,295 +1,295 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.iam; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.net.HttpURLConnection; -import java.net.URL; -import java.sql.Connection; -import java.sql.SQLException; -import java.time.Instant; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.jdbc.Driver; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.plugin.TokenInfo; -import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.telemetry.TelemetryContext; -import software.amazon.jdbc.util.telemetry.TelemetryCounter; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; -import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; - -class IamAuthConnectionPluginTest { - - private static final String GENERATED_TOKEN = "generatedToken"; - private static final String TEST_TOKEN = "testToken"; - private static final int DEFAULT_PG_PORT = 5432; - private static final int DEFAULT_MYSQL_PORT = 3306; - private static final String PG_CACHE_KEY = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" - + DEFAULT_PG_PORT + ":postgresqlUser"; - private static final String MYSQL_CACHE_KEY = "us-east-2:mysql.testdb.us-east-2.rds.amazonaws.com:" - + DEFAULT_MYSQL_PORT + ":mysqlUser"; - private static final String PG_DRIVER_PROTOCOL = "jdbc:postgresql:"; - private static final String MYSQL_DRIVER_PROTOCOL = "jdbc:mysql:"; - private static final HostSpec PG_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); - private static final HostSpec PG_HOST_SPEC_WITH_PORT = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("pg.testdb.us-east-2.rds.amazonaws.com").port(1234).build(); - private static final HostSpec PG_HOST_SPEC_WITH_REGION = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("pg.testdb.us-west-1.rds.amazonaws.com").build(); - private static final HostSpec MYSQL_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("mysql.testdb.us-east-2.rds.amazonaws.com").build(); - private Properties props; - - @Mock PluginService mockPluginService; - @Mock TelemetryFactory mockTelemetryFactory; - @Mock TelemetryCounter mockTelemetryCounter; - @Mock TelemetryContext mockTelemetryContext; - @Mock JdbcCallable mockLambda; - @Mock Dialect mockDialect; - @Mock private RdsUtils mockRdsUtils; - @Mock private IamTokenUtility mockIamTokenUtils; - private AutoCloseable closable; - - @BeforeEach - public void init() { - closable = MockitoAnnotations.openMocks(this); - props = new Properties(); - props.setProperty(PropertyDefinition.USER.name, "postgresqlUser"); - props.setProperty(PropertyDefinition.PASSWORD.name, "postgresqlPassword"); - props.setProperty(PropertyDefinition.PLUGINS.name, "iam"); - IamAuthConnectionPlugin.clearCache(); - - when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); - when(mockIamTokenUtils.generateAuthenticationToken( - any(AwsCredentialsProvider.class), - any(Region.class), - anyString(), - anyInt(), - anyString())).thenReturn(GENERATED_TOKEN); - when(mockPluginService.getDialect()).thenReturn(mockDialect); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); - when(mockTelemetryFactory.openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED))).thenReturn( - mockTelemetryContext); - } - - @AfterEach - public void cleanUp() throws Exception { - closable.close(); - } - - @BeforeAll - public static void registerDrivers() throws SQLException { - if (!org.postgresql.Driver.isRegistered()) { - org.postgresql.Driver.register(); - } - - if (!Driver.isRegistered()) { - Driver.register(); - } - } - - @Test - public void testPostgresConnectValidTokenInCache() throws SQLException { - IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); - - testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); - } - - @Test - public void testMySqlConnectValidTokenInCache() throws SQLException { - props.setProperty(PropertyDefinition.USER.name, "mysqlUser"); - props.setProperty(PropertyDefinition.PASSWORD.name, "mysqlPassword"); - IamAuthCacheHolder.tokenCache.put(MYSQL_CACHE_KEY, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_MYSQL_PORT); - - testTokenSetInProps(MYSQL_DRIVER_PROTOCOL, MYSQL_HOST_SPEC); - } - - @Test - public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLException { - final String invalidIamDefaultPort = "0"; - props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); - - final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" - + PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser"; - IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); - } - - @Test - public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() throws SQLException { - final String invalidIamDefaultPort = "0"; - props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); - - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); - - final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" - + DEFAULT_PG_PORT + ":postgresqlUser"; - IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); - } - - @Test - public void testConnectExpiredTokenInCache() throws SQLException { - IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, - new TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000))); - - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); - - testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); - } - - @Test - public void testConnectEmptyCache() throws SQLException { - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); - - testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); - } - - @Test - public void testConnectWithSpecifiedPort() throws SQLException { - final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:" + "postgresqlUser"; - IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); - } - - @Test - public void testConnectWithSpecifiedIamDefaultPort() throws SQLException { - final String iamDefaultPort = "9999"; - props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, iamDefaultPort); - final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" - + iamDefaultPort + ":postgresqlUser"; - IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); - } - - @Test - public void testConnectWithSpecifiedRegion() throws SQLException { - final String cacheKeyWithNewRegion = - "us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":" + "postgresqlUser"; - props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-west-1"); - IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewRegion, - new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); - - testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_REGION); - } - - @Test - public void testConnectWithSpecifiedHost() throws SQLException { - props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-east-2"); - props.setProperty(IamAuthConnectionPlugin.IAM_HOST.name, "pg.testdb.us-east-2.rds.amazonaws.com"); - - when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); - - testGenerateToken( - PG_DRIVER_PROTOCOL, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("8.8.8.8").build(), - "pg.testdb.us-east-2.rds.amazonaws.com"); - } - - @Test - public void testAwsSupportedRegionsUrlExists() throws IOException { - final URL url = - new URL("https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html"); - final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); - final int responseCode = urlConnection.getResponseCode(); - - assertEquals(HttpURLConnection.HTTP_OK, responseCode); - } - - public void testTokenSetInProps(final String protocol, final HostSpec hostSpec) throws SQLException { - - IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); - doThrow(new SQLException()).when(mockLambda).call(); - - assertThrows(SQLException.class, () -> targetPlugin.connect(protocol, hostSpec, props, true, mockLambda)); - verify(mockLambda, times(1)).call(); - - assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - } - - private void testGenerateToken(final String protocol, final HostSpec hostSpec) throws SQLException { - testGenerateToken(protocol, hostSpec, hostSpec.getHost()); - } - - private void testGenerateToken( - final String protocol, - final HostSpec hostSpec, - final String expectedHost) throws SQLException { - final IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); - final IamAuthConnectionPlugin spyPlugin = Mockito.spy(targetPlugin); - - doThrow(new SQLException()).when(mockLambda).call(); - - assertThrows(SQLException.class, - () -> spyPlugin.connect(protocol, hostSpec, props, true, mockLambda)); - - verify(mockIamTokenUtils).generateAuthenticationToken( - any(DefaultCredentialsProvider.class), - eq(Region.US_EAST_2), - eq(expectedHost), - eq(DEFAULT_PG_PORT), - eq("postgresqlUser")); - verify(mockLambda, times(1)).call(); - - assertEquals(GENERATED_TOKEN, PropertyDefinition.PASSWORD.getString(props)); - assertEquals(GENERATED_TOKEN, IamAuthCacheHolder.tokenCache.get(PG_CACHE_KEY).getToken()); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.iam; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.ArgumentMatchers.anyString; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doThrow; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import java.io.IOException; +// import java.net.HttpURLConnection; +// import java.net.URL; +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.time.Instant; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeAll; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.Mockito; +// import org.mockito.MockitoAnnotations; +// import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +// import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +// import software.amazon.awssdk.regions.Region; +// import software.amazon.jdbc.Driver; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.plugin.TokenInfo; +// import software.amazon.jdbc.util.RdsUtils; +// import software.amazon.jdbc.util.telemetry.TelemetryContext; +// import software.amazon.jdbc.util.telemetry.TelemetryCounter; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; +// +// class IamAuthConnectionPluginTest { +// +// private static final String GENERATED_TOKEN = "generatedToken"; +// private static final String TEST_TOKEN = "testToken"; +// private static final int DEFAULT_PG_PORT = 5432; +// private static final int DEFAULT_MYSQL_PORT = 3306; +// private static final String PG_CACHE_KEY = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" +// + DEFAULT_PG_PORT + ":postgresqlUser"; +// private static final String MYSQL_CACHE_KEY = "us-east-2:mysql.testdb.us-east-2.rds.amazonaws.com:" +// + DEFAULT_MYSQL_PORT + ":mysqlUser"; +// private static final String PG_DRIVER_PROTOCOL = "jdbc:postgresql:"; +// private static final String MYSQL_DRIVER_PROTOCOL = "jdbc:mysql:"; +// private static final HostSpec PG_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); +// private static final HostSpec PG_HOST_SPEC_WITH_PORT = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("pg.testdb.us-east-2.rds.amazonaws.com").port(1234).build(); +// private static final HostSpec PG_HOST_SPEC_WITH_REGION = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("pg.testdb.us-west-1.rds.amazonaws.com").build(); +// private static final HostSpec MYSQL_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("mysql.testdb.us-east-2.rds.amazonaws.com").build(); +// private Properties props; +// +// @Mock PluginService mockPluginService; +// @Mock TelemetryFactory mockTelemetryFactory; +// @Mock TelemetryCounter mockTelemetryCounter; +// @Mock TelemetryContext mockTelemetryContext; +// @Mock JdbcCallable mockLambda; +// @Mock Dialect mockDialect; +// @Mock private RdsUtils mockRdsUtils; +// @Mock private IamTokenUtility mockIamTokenUtils; +// private AutoCloseable closable; +// +// @BeforeEach +// public void init() { +// closable = MockitoAnnotations.openMocks(this); +// props = new Properties(); +// props.setProperty(PropertyDefinition.USER.name, "postgresqlUser"); +// props.setProperty(PropertyDefinition.PASSWORD.name, "postgresqlPassword"); +// props.setProperty(PropertyDefinition.PLUGINS.name, "iam"); +// IamAuthConnectionPlugin.clearCache(); +// +// when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); +// when(mockIamTokenUtils.generateAuthenticationToken( +// any(AwsCredentialsProvider.class), +// any(Region.class), +// anyString(), +// anyInt(), +// anyString())).thenReturn(GENERATED_TOKEN); +// when(mockPluginService.getDialect()).thenReturn(mockDialect); +// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); +// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); +// when(mockTelemetryFactory.openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED))).thenReturn( +// mockTelemetryContext); +// } +// +// @AfterEach +// public void cleanUp() throws Exception { +// closable.close(); +// } +// +// @BeforeAll +// public static void registerDrivers() throws SQLException { +// if (!org.postgresql.Driver.isRegistered()) { +// org.postgresql.Driver.register(); +// } +// +// if (!Driver.isRegistered()) { +// Driver.register(); +// } +// } +// +// @Test +// public void testPostgresConnectValidTokenInCache() throws SQLException { +// IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); +// +// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); +// } +// +// @Test +// public void testMySqlConnectValidTokenInCache() throws SQLException { +// props.setProperty(PropertyDefinition.USER.name, "mysqlUser"); +// props.setProperty(PropertyDefinition.PASSWORD.name, "mysqlPassword"); +// IamAuthCacheHolder.tokenCache.put(MYSQL_CACHE_KEY, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_MYSQL_PORT); +// +// testTokenSetInProps(MYSQL_DRIVER_PROTOCOL, MYSQL_HOST_SPEC); +// } +// +// @Test +// public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLException { +// final String invalidIamDefaultPort = "0"; +// props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); +// +// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" +// + PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser"; +// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); +// } +// +// @Test +// public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() throws SQLException { +// final String invalidIamDefaultPort = "0"; +// props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); +// +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); +// +// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" +// + DEFAULT_PG_PORT + ":postgresqlUser"; +// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); +// } +// +// @Test +// public void testConnectExpiredTokenInCache() throws SQLException { +// IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, +// new TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000))); +// +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); +// +// testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); +// } +// +// @Test +// public void testConnectEmptyCache() throws SQLException { +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); +// +// testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); +// } +// +// @Test +// public void testConnectWithSpecifiedPort() throws SQLException { +// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:" + "postgresqlUser"; +// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); +// } +// +// @Test +// public void testConnectWithSpecifiedIamDefaultPort() throws SQLException { +// final String iamDefaultPort = "9999"; +// props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, iamDefaultPort); +// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" +// + iamDefaultPort + ":postgresqlUser"; +// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); +// } +// +// @Test +// public void testConnectWithSpecifiedRegion() throws SQLException { +// final String cacheKeyWithNewRegion = +// "us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":" + "postgresqlUser"; +// props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-west-1"); +// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewRegion, +// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); +// +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); +// +// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_REGION); +// } +// +// @Test +// public void testConnectWithSpecifiedHost() throws SQLException { +// props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-east-2"); +// props.setProperty(IamAuthConnectionPlugin.IAM_HOST.name, "pg.testdb.us-east-2.rds.amazonaws.com"); +// +// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); +// +// testGenerateToken( +// PG_DRIVER_PROTOCOL, +// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("8.8.8.8").build(), +// "pg.testdb.us-east-2.rds.amazonaws.com"); +// } +// +// @Test +// public void testAwsSupportedRegionsUrlExists() throws IOException { +// final URL url = +// new URL("https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html"); +// final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); +// final int responseCode = urlConnection.getResponseCode(); +// +// assertEquals(HttpURLConnection.HTTP_OK, responseCode); +// } +// +// public void testTokenSetInProps(final String protocol, final HostSpec hostSpec) throws SQLException { +// +// IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); +// doThrow(new SQLException()).when(mockLambda).call(); +// +// assertThrows(SQLException.class, () -> targetPlugin.connect(protocol, hostSpec, props, true, mockLambda)); +// verify(mockLambda, times(1)).call(); +// +// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// } +// +// private void testGenerateToken(final String protocol, final HostSpec hostSpec) throws SQLException { +// testGenerateToken(protocol, hostSpec, hostSpec.getHost()); +// } +// +// private void testGenerateToken( +// final String protocol, +// final HostSpec hostSpec, +// final String expectedHost) throws SQLException { +// final IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); +// final IamAuthConnectionPlugin spyPlugin = Mockito.spy(targetPlugin); +// +// doThrow(new SQLException()).when(mockLambda).call(); +// +// assertThrows(SQLException.class, +// () -> spyPlugin.connect(protocol, hostSpec, props, true, mockLambda)); +// +// verify(mockIamTokenUtils).generateAuthenticationToken( +// any(DefaultCredentialsProvider.class), +// eq(Region.US_EAST_2), +// eq(expectedHost), +// eq(DEFAULT_PG_PORT), +// eq("postgresqlUser")); +// verify(mockLambda, times(1)).call(); +// +// assertEquals(GENERATED_TOKEN, PropertyDefinition.PASSWORD.getString(props)); +// assertEquals(GENERATED_TOKEN, IamAuthCacheHolder.tokenCache.get(PG_CACHE_KEY).getToken()); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java index 411233100..3aa1b40de 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java @@ -1,162 +1,162 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.limitless; - -import static org.junit.Assert.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static software.amazon.jdbc.plugin.limitless.LimitlessConnectionPlugin.INTERVAL_MILLIS; - -import java.sql.Connection; -import java.sql.SQLException; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import software.amazon.jdbc.HostListProvider; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.AuroraPgDialect; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.PgDialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; - -public class LimitlessConnectionPluginTest { - - private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; - private static final HostSpec INPUT_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); - private static final String CLUSTER_ID = "someClusterId"; - - private static final HostSpec expectedSelectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("expected-selected-instance").role(HostRole.WRITER).weight(Long.MAX_VALUE).build(); - private static final Dialect supportedDialect = new AuroraPgDialect(); - @Mock JdbcCallable mockConnectFuncLambda; - @Mock private Connection mockConnection; - @Mock private PluginService mockPluginService; - @Mock private HostListProvider mockHostListProvider; - @Mock private LimitlessRouterService mockLimitlessRouterService; - private static Properties props; - - private static LimitlessConnectionPlugin plugin; - - private AutoCloseable closeable; - - @BeforeEach - public void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - props = new Properties(); - plugin = new LimitlessConnectionPlugin(mockPluginService, props, () -> mockLimitlessRouterService); - - when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); - when(mockPluginService.getDialect()).thenReturn(supportedDialect); - when(mockHostListProvider.getClusterId()).thenReturn(CLUSTER_ID); - when(mockConnectFuncLambda.call()).thenReturn(mockConnection); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - } - - @Test - void testConnect() throws SQLException { - doAnswer(new Answer() { - public Void answer(InvocationOnMock invocation) { - LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; - context.setConnection(mockConnection); - return null; - } - }).when(mockLimitlessRouterService).establishConnection(any()); - - final Connection expectedConnection = mockConnection; - final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, - mockConnectFuncLambda); - - assertEquals(expectedConnection, actualConnection); - verify(mockPluginService, times(1)).getDialect(); - verify(mockConnectFuncLambda, times(0)).call(); - verify(mockLimitlessRouterService, times(1)) - .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); - verify(mockLimitlessRouterService, times(1)).establishConnection(any()); - } - - @Test - void testConnectGivenNullConnection() throws SQLException { - doAnswer(new Answer() { - public Void answer(InvocationOnMock invocation) { - LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; - context.setConnection(null); - return null; - } - }).when(mockLimitlessRouterService).establishConnection(any()); - - assertThrows( - SQLException.class, - () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); - - verify(mockPluginService, times(1)).getDialect(); - verify(mockConnectFuncLambda, times(0)).call(); - verify(mockLimitlessRouterService, times(1)) - .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); - verify(mockLimitlessRouterService, times(1)).establishConnection(any()); - } - - @Test - void testConnectGivenUnsupportedDialect() throws SQLException { - final Dialect unsupportedDialect = new PgDialect(); - when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, unsupportedDialect); - - assertThrows( - UnsupportedOperationException.class, - () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); - - verify(mockPluginService, times(2)).getDialect(); - verify(mockConnectFuncLambda, times(1)).call(); - verify(mockLimitlessRouterService, times(0)) - .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); - verify(mockLimitlessRouterService, times(0)).establishConnection(any()); - } - - @Test - void testConnectGivenSupportedDialectAfterRefresh() throws SQLException { - final Dialect unsupportedDialect = new PgDialect(); - when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, supportedDialect); - - final Connection expectedConnection = mockConnection; - final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, - mockConnectFuncLambda); - - assertEquals(expectedConnection, actualConnection); - verify(mockPluginService, times(2)).getDialect(); - verify(mockConnectFuncLambda, times(1)).call(); - verify(mockLimitlessRouterService, times(1)) - .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); - verify(mockLimitlessRouterService, times(1)).establishConnection(any()); - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.limitless; +// +// import static org.junit.Assert.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.Mockito.doAnswer; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// import static software.amazon.jdbc.plugin.limitless.LimitlessConnectionPlugin.INTERVAL_MILLIS; +// +// import java.sql.Connection; +// import java.sql.SQLException; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import org.mockito.invocation.InvocationOnMock; +// import org.mockito.stubbing.Answer; +// import software.amazon.jdbc.HostListProvider; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.dialect.AuroraPgDialect; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.dialect.PgDialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// +// public class LimitlessConnectionPluginTest { +// +// private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; +// private static final HostSpec INPUT_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); +// private static final String CLUSTER_ID = "someClusterId"; +// +// private static final HostSpec expectedSelectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("expected-selected-instance").role(HostRole.WRITER).weight(Long.MAX_VALUE).build(); +// private static final Dialect supportedDialect = new AuroraPgDialect(); +// @Mock JdbcCallable mockConnectFuncLambda; +// @Mock private Connection mockConnection; +// @Mock private PluginService mockPluginService; +// @Mock private HostListProvider mockHostListProvider; +// @Mock private LimitlessRouterService mockLimitlessRouterService; +// private static Properties props; +// +// private static LimitlessConnectionPlugin plugin; +// +// private AutoCloseable closeable; +// +// @BeforeEach +// public void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// props = new Properties(); +// plugin = new LimitlessConnectionPlugin(mockPluginService, props, () -> mockLimitlessRouterService); +// +// when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); +// when(mockPluginService.getDialect()).thenReturn(supportedDialect); +// when(mockHostListProvider.getClusterId()).thenReturn(CLUSTER_ID); +// when(mockConnectFuncLambda.call()).thenReturn(mockConnection); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// } +// +// @Test +// void testConnect() throws SQLException { +// doAnswer(new Answer() { +// public Void answer(InvocationOnMock invocation) { +// LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; +// context.setConnection(mockConnection); +// return null; +// } +// }).when(mockLimitlessRouterService).establishConnection(any()); +// +// final Connection expectedConnection = mockConnection; +// final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, +// mockConnectFuncLambda); +// +// assertEquals(expectedConnection, actualConnection); +// verify(mockPluginService, times(1)).getDialect(); +// verify(mockConnectFuncLambda, times(0)).call(); +// verify(mockLimitlessRouterService, times(1)) +// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); +// verify(mockLimitlessRouterService, times(1)).establishConnection(any()); +// } +// +// @Test +// void testConnectGivenNullConnection() throws SQLException { +// doAnswer(new Answer() { +// public Void answer(InvocationOnMock invocation) { +// LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; +// context.setConnection(null); +// return null; +// } +// }).when(mockLimitlessRouterService).establishConnection(any()); +// +// assertThrows( +// SQLException.class, +// () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); +// +// verify(mockPluginService, times(1)).getDialect(); +// verify(mockConnectFuncLambda, times(0)).call(); +// verify(mockLimitlessRouterService, times(1)) +// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); +// verify(mockLimitlessRouterService, times(1)).establishConnection(any()); +// } +// +// @Test +// void testConnectGivenUnsupportedDialect() throws SQLException { +// final Dialect unsupportedDialect = new PgDialect(); +// when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, unsupportedDialect); +// +// assertThrows( +// UnsupportedOperationException.class, +// () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); +// +// verify(mockPluginService, times(2)).getDialect(); +// verify(mockConnectFuncLambda, times(1)).call(); +// verify(mockLimitlessRouterService, times(0)) +// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); +// verify(mockLimitlessRouterService, times(0)).establishConnection(any()); +// } +// +// @Test +// void testConnectGivenSupportedDialectAfterRefresh() throws SQLException { +// final Dialect unsupportedDialect = new PgDialect(); +// when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, supportedDialect); +// +// final Connection expectedConnection = mockConnection; +// final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, +// mockConnectFuncLambda); +// +// assertEquals(expectedConnection, actualConnection); +// verify(mockPluginService, times(2)).getDialect(); +// verify(mockConnectFuncLambda, times(1)).call(); +// verify(mockLimitlessRouterService, times(1)) +// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); +// verify(mockLimitlessRouterService, times(1)).establishConnection(any()); +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index c7c7bdc1b..1dcb8e62c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -1,626 +1,626 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.plugin.readwritesplitting; - -import static org.junit.Assert.assertThrows; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.mockito.AdditionalMatchers.not; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.zaxxer.hikari.HikariConfig; -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Arrays; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Properties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.HostListProviderService; -import software.amazon.jdbc.HostRole; -import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.HostSpecBuilder; -import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.NodeChangeOptions; -import software.amazon.jdbc.OldConnectionSuggestedAction; -import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; -import software.amazon.jdbc.util.SqlState; - -public class ReadWriteSplittingPluginTest { - private static final String TEST_PROTOCOL = "jdbc:postgresql:"; - private static final int TEST_PORT = 5432; - private static final Properties defaultProps = new Properties(); - - private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-0").port(TEST_PORT).build(); - private final HostSpec readerHostSpec1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-1").port(TEST_PORT).role(HostRole.READER).build(); - private final HostSpec readerHostSpec2 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-2").port(TEST_PORT).role(HostRole.READER).build(); - private final HostSpec readerHostSpec3 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-3").port(TEST_PORT).role(HostRole.READER).build(); - private final HostSpec readerHostSpecWithIncorrectRole = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("instance-1").port(TEST_PORT).role(HostRole.WRITER).build(); - private final HostSpec instanceUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("jdbc:aws-wrapper:postgresql://my-instance-name.XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT) - .build(); - private final HostSpec ipUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("10.10.10.10").port(TEST_PORT).build(); - private final HostSpec clusterUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT).build(); - private final List defaultHosts = Arrays.asList( - writerHostSpec, - readerHostSpec1, - readerHostSpec2, - readerHostSpec3); - private final List singleReaderTopology = Arrays.asList( - writerHostSpec, - readerHostSpec1); - - private AutoCloseable closeable; - - @Mock private JdbcCallable mockConnectFunc; - @Mock private JdbcCallable mockSqlFunction; - @Mock private PluginService mockPluginService; - @Mock private Dialect mockDialect; - @Mock private HostListProviderService mockHostListProviderService; - @Mock private Connection mockWriterConn; - @Mock private Connection mockNewWriterConn; - @Mock private Connection mockClosedWriterConn; - @Mock private Connection mockReaderConn1; - @Mock private Connection mockReaderConn2; - @Mock private Connection mockReaderConn3; - @Mock private Statement mockStatement; - @Mock private ResultSet mockResultSet; - @Mock private EnumSet mockChanges; - - @BeforeEach - public void init() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - mockDefaultBehavior(); - } - - @AfterEach - void cleanUp() throws Exception { - closeable.close(); - defaultProps.clear(); - } - - void mockDefaultBehavior() throws SQLException { - when(this.mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); - when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); - when(this.mockPluginService.getAllHosts()).thenReturn(defaultHosts); - when(this.mockPluginService.getHosts()).thenReturn(defaultHosts); - when(this.mockPluginService.getHostSpecByStrategy(eq(HostRole.READER), eq("random"))) - .thenReturn(readerHostSpec1); - when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class))) - .thenReturn(mockWriterConn); - when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class), any())) - .thenReturn(mockWriterConn); - when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(writerHostSpec); - when(this.mockPluginService.getHostRole(mockWriterConn)).thenReturn(HostRole.WRITER); - when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(HostRole.READER); - when(this.mockPluginService.getHostRole(mockReaderConn2)).thenReturn(HostRole.READER); - when(this.mockPluginService.getHostRole(mockReaderConn3)).thenReturn(HostRole.READER); - when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class))) - .thenReturn(mockReaderConn1); - when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class), any())) - .thenReturn(mockReaderConn1); - when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class))) - .thenReturn(mockReaderConn2); - when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class), any())) - .thenReturn(mockReaderConn2); - when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class))) - .thenReturn(mockReaderConn3); - when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class), any())) - .thenReturn(mockReaderConn3); - when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); - when(this.mockConnectFunc.call()).thenReturn(mockWriterConn); - when(mockWriterConn.createStatement()).thenReturn(mockStatement); - when(mockReaderConn1.createStatement()).thenReturn(mockStatement); - when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true); - when(mockClosedWriterConn.isClosed()).thenReturn(true); - } - - @Test - public void testSetReadOnly_trueFalse() throws SQLException { - when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); - when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - plugin.switchConnectionIfRequired(true); - - verify(mockPluginService, times(1)) - .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); - verify(mockPluginService, times(0)) - .setCurrentConnection(eq(mockWriterConn), any(HostSpec.class)); - assertEquals(mockReaderConn1, plugin.getReaderConnection()); - assertEquals(mockWriterConn, plugin.getWriterConnection()); - - when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); - when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); - - plugin.switchConnectionIfRequired(false); - - verify(mockPluginService, times(1)) - .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); - verify(mockPluginService, times(1)) - .setCurrentConnection(eq(mockWriterConn), eq(writerHostSpec)); - assertEquals(mockReaderConn1, plugin.getReaderConnection()); - assertEquals(mockWriterConn, plugin.getWriterConnection()); - } - - @Test - public void testSetReadOnlyTrue_alreadyOnReader() throws SQLException { - when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); - when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); - when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - mockReaderConn1); - - plugin.switchConnectionIfRequired(true); - - verify(mockPluginService, times(0)) - .setCurrentConnection(any(Connection.class), any(HostSpec.class)); - assertEquals(mockReaderConn1, plugin.getReaderConnection()); - assertNull(plugin.getWriterConnection()); - } - - @Test - public void testSetReadOnlyFalse_alreadyOnWriter() throws SQLException { - when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); - when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); - when(mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - plugin.switchConnectionIfRequired(false); - - verify(mockPluginService, times(0)) - .setCurrentConnection(any(Connection.class), any(HostSpec.class)); - assertEquals(mockWriterConn, plugin.getWriterConnection()); - assertNull(plugin.getReaderConnection()); - } - - @Test - public void testSetReadOnly_falseInTransaction() { - when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); - when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); - when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); - when(mockPluginService.isInTransaction()).thenReturn(true); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - mockReaderConn1); - - final SQLException e = - assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); - assertEquals(SqlState.ACTIVE_SQL_TRANSACTION.getState(), e.getSQLState()); - } - - @Test - public void testSetReadOnly_true() throws SQLException { - final ReadWriteSplittingPlugin plugin = - new ReadWriteSplittingPlugin(mockPluginService, defaultProps); - plugin.switchConnectionIfRequired(true); - - assertEquals(mockReaderConn1, plugin.getReaderConnection()); - } - - @Test - public void testSetReadOnly_false() throws SQLException { - when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); - when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - mockReaderConn1); - plugin.switchConnectionIfRequired(false); - - assertEquals(mockWriterConn, plugin.getWriterConnection()); - } - - @Test - public void testSetReadOnly_true_oneHost() throws SQLException { - when(this.mockPluginService.getHosts()).thenReturn(Collections.singletonList(writerHostSpec)); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - plugin.switchConnectionIfRequired(true); - - verify(mockPluginService, times(0)) - .setCurrentConnection(any(Connection.class), any(HostSpec.class)); - assertEquals(mockWriterConn, plugin.getWriterConnection()); - assertNull(plugin.getReaderConnection()); - } - - @Test - public void testSetReadOnly_false_writerConnectionFails() throws SQLException { - when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps), any())) - .thenThrow(SQLException.class); - when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); - when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); - when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - mockReaderConn1); - - final SQLException e = - assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); - assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), e.getSQLState()); - verify(mockPluginService, times(0)) - .setCurrentConnection(any(Connection.class), any(HostSpec.class)); - } - - @Test - public void testSetReadOnly_true_readerConnectionFailed() throws SQLException { - when(this.mockPluginService.connect(eq(readerHostSpec1), eq(defaultProps), any())) - .thenThrow(SQLException.class); - when(this.mockPluginService.connect(eq(readerHostSpec2), eq(defaultProps), any())) - .thenThrow(SQLException.class); - when(this.mockPluginService.connect(eq(readerHostSpec3), eq(defaultProps), any())) - .thenThrow(SQLException.class); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - plugin.switchConnectionIfRequired(true); - - verify(mockPluginService, times(0)) - .setCurrentConnection(any(Connection.class), any(HostSpec.class)); - assertNull(plugin.getReaderConnection()); - } - - @Test - public void testSetReadOnlyOnClosedConnection() throws SQLException { - when(mockPluginService.getCurrentConnection()).thenReturn(mockClosedWriterConn); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockClosedWriterConn, - null); - - final SQLException e = - assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(true)); - assertEquals(SqlState.CONNECTION_NOT_OPEN.getState(), e.getSQLState()); - verify(mockPluginService, times(0)) - .setCurrentConnection(any(Connection.class), any(HostSpec.class)); - assertNull(plugin.getReaderConnection()); - } - - @Test - public void testExecute_failoverToNewWriter() throws SQLException { - when(mockSqlFunction.call()).thenThrow(FailoverSuccessSQLException.class); - when(mockPluginService.getCurrentConnection()).thenReturn(mockNewWriterConn); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - - assertThrows( - SQLException.class, - () -> plugin.execute( - ResultSet.class, - SQLException.class, - mockStatement, - "Statement.executeQuery", - mockSqlFunction, - new Object[] { - "begin"})); - verify(mockWriterConn, times(1)).close(); - } - - @Test - public void testNotifyConnectionChange() { - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - - final OldConnectionSuggestedAction suggestion = plugin.notifyConnectionChanged(mockChanges); - - assertEquals(mockWriterConn, plugin.getWriterConnection()); - assertEquals(OldConnectionSuggestedAction.NO_OPINION, suggestion); - } - - @Test - public void testConnectNonInitialConnection() throws SQLException { - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - - final Connection connection = - plugin.connect(TEST_PROTOCOL, writerHostSpec, defaultProps, false, this.mockConnectFunc); - - assertEquals(mockWriterConn, connection); - verify(mockConnectFunc).call(); - verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); - } - - @Test - public void testConnectRdsInstanceUrl() throws SQLException { - when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); - when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - final Connection connection = plugin.connect( - TEST_PROTOCOL, - instanceUrlHostSpec, - defaultProps, - true, - this.mockConnectFunc); - - assertEquals(mockReaderConn1, connection); - verify(mockConnectFunc).call(); - verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); - } - - @Test - public void testConnectReaderIpUrl() throws SQLException { - when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); - when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - final Connection connection = - plugin.connect(TEST_PROTOCOL, ipUrlHostSpec, defaultProps, true, this.mockConnectFunc); - - assertEquals(mockReaderConn1, connection); - verify(mockConnectFunc).call(); - verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); - } - - @Test - public void testConnectClusterUrl() throws SQLException { - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - final Connection connection = - plugin.connect(TEST_PROTOCOL, clusterUrlHostSpec, defaultProps, true, this.mockConnectFunc); - - assertEquals(mockWriterConn, connection); - verify(mockConnectFunc).call(); - verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); - } - - @Test - public void testConnect_errorUpdatingHostSpec() throws SQLException { - when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); - when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(null); - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - - assertThrows( - SQLException.class, - () -> plugin.connect( - TEST_PROTOCOL, - ipUrlHostSpec, - defaultProps, - true, - this.mockConnectFunc)); - verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); - } - - @Test - public void testExecuteClearWarnings() throws SQLException { - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - mockReaderConn1); - - plugin.execute( - ResultSet.class, - SQLException.class, - mockStatement, - "Connection.clearWarnings", - mockSqlFunction, - new Object[] {} - ); - verify(mockWriterConn, times(1)).clearWarnings(); - verify(mockReaderConn1, times(1)).clearWarnings(); - } - - @Test - public void testExecuteClearWarningsOnClosedConnectionsIsNotCalled() throws SQLException { - when(mockWriterConn.isClosed()).thenReturn(true); - when(mockReaderConn1.isClosed()).thenReturn(true); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - mockReaderConn1); - - plugin.execute( - ResultSet.class, - SQLException.class, - mockStatement, - "Connection.clearWarnings", - mockSqlFunction, - new Object[] {} - ); - verify(mockWriterConn, never()).clearWarnings(); - verify(mockReaderConn1, never()).clearWarnings(); - } - - @Test - public void testExecuteClearWarningsOnNullConnectionsIsNotCalled() throws SQLException { - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - - // calling clearWarnings() on nullified connection would throw an exception - assertDoesNotThrow(() -> { - plugin.execute( - ResultSet.class, - SQLException.class, - mockStatement, - "Connection.clearWarnings", - mockSqlFunction, - new Object[] {} - ); - }); - } - - @Test - public void testClosePooledReaderConnectionAfterSetReadOnly() throws SQLException { - doReturn(writerHostSpec) - .doReturn(writerHostSpec) - .doReturn(readerHostSpec1) - .when(this.mockPluginService).getCurrentHostSpec(); - doReturn(mockReaderConn1).when(mockPluginService).connect(readerHostSpec1, null); - when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); - when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - mockWriterConn, - null); - final ReadWriteSplittingPlugin spyPlugin = spy(plugin); - - spyPlugin.switchConnectionIfRequired(true); - spyPlugin.switchConnectionIfRequired(false); - - verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockReaderConn1)); - } - - @Test - public void testClosePooledWriterConnectionAfterSetReadOnly() throws SQLException { - doReturn(writerHostSpec) - .doReturn(writerHostSpec) - .doReturn(readerHostSpec1) - .doReturn(readerHostSpec1) - .doReturn(writerHostSpec) - .when(this.mockPluginService).getCurrentHostSpec(); - doReturn(mockWriterConn).when(mockPluginService).connect(writerHostSpec, null); - when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); - when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); - - final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( - mockPluginService, - defaultProps, - mockHostListProviderService, - null, - null); - final ReadWriteSplittingPlugin spyPlugin = spy(plugin); - - spyPlugin.switchConnectionIfRequired(true); - spyPlugin.switchConnectionIfRequired(false); - spyPlugin.switchConnectionIfRequired(true); - - verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockWriterConn)); - } - - private static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props) { - final HikariConfig config = new HikariConfig(); - config.setMaximumPoolSize(3); - config.setInitializationFailTimeout(75000); - config.setConnectionTimeout(10000); - return config; - } - - private static String getPoolKey(HostSpec hostSpec, Properties props) { - final String user = props.getProperty(PropertyDefinition.USER.name); - final String somePropertyValue = props.getProperty("somePropertyValue"); - return hostSpec.getUrl() + user + somePropertyValue; - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.plugin.readwritesplitting; +// +// import static org.junit.Assert.assertThrows; +// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.mockito.AdditionalMatchers.not; +// import static org.mockito.ArgumentMatchers.any; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.never; +// import static org.mockito.Mockito.spy; +// import static org.mockito.Mockito.times; +// import static org.mockito.Mockito.verify; +// import static org.mockito.Mockito.when; +// +// import com.zaxxer.hikari.HikariConfig; +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.SQLException; +// import java.sql.Statement; +// import java.util.Arrays; +// import java.util.Collections; +// import java.util.EnumSet; +// import java.util.List; +// import java.util.Properties; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.HostListProviderService; +// import software.amazon.jdbc.HostRole; +// import software.amazon.jdbc.HostSpec; +// import software.amazon.jdbc.HostSpecBuilder; +// import software.amazon.jdbc.JdbcCallable; +// import software.amazon.jdbc.NodeChangeOptions; +// import software.amazon.jdbc.OldConnectionSuggestedAction; +// import software.amazon.jdbc.PluginService; +// import software.amazon.jdbc.PropertyDefinition; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +// import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; +// import software.amazon.jdbc.util.SqlState; +// +// public class ReadWriteSplittingPluginTest { +// private static final String TEST_PROTOCOL = "jdbc:postgresql:"; +// private static final int TEST_PORT = 5432; +// private static final Properties defaultProps = new Properties(); +// +// private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-0").port(TEST_PORT).build(); +// private final HostSpec readerHostSpec1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-1").port(TEST_PORT).role(HostRole.READER).build(); +// private final HostSpec readerHostSpec2 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-2").port(TEST_PORT).role(HostRole.READER).build(); +// private final HostSpec readerHostSpec3 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-3").port(TEST_PORT).role(HostRole.READER).build(); +// private final HostSpec readerHostSpecWithIncorrectRole = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("instance-1").port(TEST_PORT).role(HostRole.WRITER).build(); +// private final HostSpec instanceUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("jdbc:aws-wrapper:postgresql://my-instance-name.XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT) +// .build(); +// private final HostSpec ipUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("10.10.10.10").port(TEST_PORT).build(); +// private final HostSpec clusterUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) +// .host("my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT).build(); +// private final List defaultHosts = Arrays.asList( +// writerHostSpec, +// readerHostSpec1, +// readerHostSpec2, +// readerHostSpec3); +// private final List singleReaderTopology = Arrays.asList( +// writerHostSpec, +// readerHostSpec1); +// +// private AutoCloseable closeable; +// +// @Mock private JdbcCallable mockConnectFunc; +// @Mock private JdbcCallable mockSqlFunction; +// @Mock private PluginService mockPluginService; +// @Mock private Dialect mockDialect; +// @Mock private HostListProviderService mockHostListProviderService; +// @Mock private Connection mockWriterConn; +// @Mock private Connection mockNewWriterConn; +// @Mock private Connection mockClosedWriterConn; +// @Mock private Connection mockReaderConn1; +// @Mock private Connection mockReaderConn2; +// @Mock private Connection mockReaderConn3; +// @Mock private Statement mockStatement; +// @Mock private ResultSet mockResultSet; +// @Mock private EnumSet mockChanges; +// +// @BeforeEach +// public void init() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// mockDefaultBehavior(); +// } +// +// @AfterEach +// void cleanUp() throws Exception { +// closeable.close(); +// defaultProps.clear(); +// } +// +// void mockDefaultBehavior() throws SQLException { +// when(this.mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); +// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); +// when(this.mockPluginService.getAllHosts()).thenReturn(defaultHosts); +// when(this.mockPluginService.getHosts()).thenReturn(defaultHosts); +// when(this.mockPluginService.getHostSpecByStrategy(eq(HostRole.READER), eq("random"))) +// .thenReturn(readerHostSpec1); +// when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class))) +// .thenReturn(mockWriterConn); +// when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class), any())) +// .thenReturn(mockWriterConn); +// when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(writerHostSpec); +// when(this.mockPluginService.getHostRole(mockWriterConn)).thenReturn(HostRole.WRITER); +// when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(HostRole.READER); +// when(this.mockPluginService.getHostRole(mockReaderConn2)).thenReturn(HostRole.READER); +// when(this.mockPluginService.getHostRole(mockReaderConn3)).thenReturn(HostRole.READER); +// when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class))) +// .thenReturn(mockReaderConn1); +// when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class), any())) +// .thenReturn(mockReaderConn1); +// when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class))) +// .thenReturn(mockReaderConn2); +// when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class), any())) +// .thenReturn(mockReaderConn2); +// when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class))) +// .thenReturn(mockReaderConn3); +// when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class), any())) +// .thenReturn(mockReaderConn3); +// when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); +// when(this.mockConnectFunc.call()).thenReturn(mockWriterConn); +// when(mockWriterConn.createStatement()).thenReturn(mockStatement); +// when(mockReaderConn1.createStatement()).thenReturn(mockStatement); +// when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); +// when(mockResultSet.next()).thenReturn(true); +// when(mockClosedWriterConn.isClosed()).thenReturn(true); +// } +// +// @Test +// public void testSetReadOnly_trueFalse() throws SQLException { +// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// plugin.switchConnectionIfRequired(true); +// +// verify(mockPluginService, times(1)) +// .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); +// verify(mockPluginService, times(0)) +// .setCurrentConnection(eq(mockWriterConn), any(HostSpec.class)); +// assertEquals(mockReaderConn1, plugin.getReaderConnection()); +// assertEquals(mockWriterConn, plugin.getWriterConnection()); +// +// when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); +// +// plugin.switchConnectionIfRequired(false); +// +// verify(mockPluginService, times(1)) +// .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); +// verify(mockPluginService, times(1)) +// .setCurrentConnection(eq(mockWriterConn), eq(writerHostSpec)); +// assertEquals(mockReaderConn1, plugin.getReaderConnection()); +// assertEquals(mockWriterConn, plugin.getWriterConnection()); +// } +// +// @Test +// public void testSetReadOnlyTrue_alreadyOnReader() throws SQLException { +// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// mockReaderConn1); +// +// plugin.switchConnectionIfRequired(true); +// +// verify(mockPluginService, times(0)) +// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); +// assertEquals(mockReaderConn1, plugin.getReaderConnection()); +// assertNull(plugin.getWriterConnection()); +// } +// +// @Test +// public void testSetReadOnlyFalse_alreadyOnWriter() throws SQLException { +// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// plugin.switchConnectionIfRequired(false); +// +// verify(mockPluginService, times(0)) +// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); +// assertEquals(mockWriterConn, plugin.getWriterConnection()); +// assertNull(plugin.getReaderConnection()); +// } +// +// @Test +// public void testSetReadOnly_falseInTransaction() { +// when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); +// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); +// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); +// when(mockPluginService.isInTransaction()).thenReturn(true); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// mockReaderConn1); +// +// final SQLException e = +// assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); +// assertEquals(SqlState.ACTIVE_SQL_TRANSACTION.getState(), e.getSQLState()); +// } +// +// @Test +// public void testSetReadOnly_true() throws SQLException { +// final ReadWriteSplittingPlugin plugin = +// new ReadWriteSplittingPlugin(mockPluginService, defaultProps); +// plugin.switchConnectionIfRequired(true); +// +// assertEquals(mockReaderConn1, plugin.getReaderConnection()); +// } +// +// @Test +// public void testSetReadOnly_false() throws SQLException { +// when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); +// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// mockReaderConn1); +// plugin.switchConnectionIfRequired(false); +// +// assertEquals(mockWriterConn, plugin.getWriterConnection()); +// } +// +// @Test +// public void testSetReadOnly_true_oneHost() throws SQLException { +// when(this.mockPluginService.getHosts()).thenReturn(Collections.singletonList(writerHostSpec)); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// plugin.switchConnectionIfRequired(true); +// +// verify(mockPluginService, times(0)) +// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); +// assertEquals(mockWriterConn, plugin.getWriterConnection()); +// assertNull(plugin.getReaderConnection()); +// } +// +// @Test +// public void testSetReadOnly_false_writerConnectionFails() throws SQLException { +// when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps), any())) +// .thenThrow(SQLException.class); +// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); +// when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// mockReaderConn1); +// +// final SQLException e = +// assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); +// assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), e.getSQLState()); +// verify(mockPluginService, times(0)) +// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); +// } +// +// @Test +// public void testSetReadOnly_true_readerConnectionFailed() throws SQLException { +// when(this.mockPluginService.connect(eq(readerHostSpec1), eq(defaultProps), any())) +// .thenThrow(SQLException.class); +// when(this.mockPluginService.connect(eq(readerHostSpec2), eq(defaultProps), any())) +// .thenThrow(SQLException.class); +// when(this.mockPluginService.connect(eq(readerHostSpec3), eq(defaultProps), any())) +// .thenThrow(SQLException.class); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// plugin.switchConnectionIfRequired(true); +// +// verify(mockPluginService, times(0)) +// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); +// assertNull(plugin.getReaderConnection()); +// } +// +// @Test +// public void testSetReadOnlyOnClosedConnection() throws SQLException { +// when(mockPluginService.getCurrentConnection()).thenReturn(mockClosedWriterConn); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockClosedWriterConn, +// null); +// +// final SQLException e = +// assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(true)); +// assertEquals(SqlState.CONNECTION_NOT_OPEN.getState(), e.getSQLState()); +// verify(mockPluginService, times(0)) +// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); +// assertNull(plugin.getReaderConnection()); +// } +// +// @Test +// public void testExecute_failoverToNewWriter() throws SQLException { +// when(mockSqlFunction.call()).thenThrow(FailoverSuccessSQLException.class); +// when(mockPluginService.getCurrentConnection()).thenReturn(mockNewWriterConn); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// +// assertThrows( +// SQLException.class, +// () -> plugin.execute( +// ResultSet.class, +// SQLException.class, +// mockStatement, +// "Statement.executeQuery", +// mockSqlFunction, +// new Object[] { +// "begin"})); +// verify(mockWriterConn, times(1)).close(); +// } +// +// @Test +// public void testNotifyConnectionChange() { +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// +// final OldConnectionSuggestedAction suggestion = plugin.notifyConnectionChanged(mockChanges); +// +// assertEquals(mockWriterConn, plugin.getWriterConnection()); +// assertEquals(OldConnectionSuggestedAction.NO_OPINION, suggestion); +// } +// +// @Test +// public void testConnectNonInitialConnection() throws SQLException { +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// +// final Connection connection = +// plugin.connect(TEST_PROTOCOL, writerHostSpec, defaultProps, false, this.mockConnectFunc); +// +// assertEquals(mockWriterConn, connection); +// verify(mockConnectFunc).call(); +// verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); +// } +// +// @Test +// public void testConnectRdsInstanceUrl() throws SQLException { +// when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); +// when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// final Connection connection = plugin.connect( +// TEST_PROTOCOL, +// instanceUrlHostSpec, +// defaultProps, +// true, +// this.mockConnectFunc); +// +// assertEquals(mockReaderConn1, connection); +// verify(mockConnectFunc).call(); +// verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); +// } +// +// @Test +// public void testConnectReaderIpUrl() throws SQLException { +// when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); +// when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// final Connection connection = +// plugin.connect(TEST_PROTOCOL, ipUrlHostSpec, defaultProps, true, this.mockConnectFunc); +// +// assertEquals(mockReaderConn1, connection); +// verify(mockConnectFunc).call(); +// verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); +// } +// +// @Test +// public void testConnectClusterUrl() throws SQLException { +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// final Connection connection = +// plugin.connect(TEST_PROTOCOL, clusterUrlHostSpec, defaultProps, true, this.mockConnectFunc); +// +// assertEquals(mockWriterConn, connection); +// verify(mockConnectFunc).call(); +// verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); +// } +// +// @Test +// public void testConnect_errorUpdatingHostSpec() throws SQLException { +// when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); +// when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(null); +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// +// assertThrows( +// SQLException.class, +// () -> plugin.connect( +// TEST_PROTOCOL, +// ipUrlHostSpec, +// defaultProps, +// true, +// this.mockConnectFunc)); +// verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); +// } +// +// @Test +// public void testExecuteClearWarnings() throws SQLException { +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// mockReaderConn1); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// mockStatement, +// "Connection.clearWarnings", +// mockSqlFunction, +// new Object[] {} +// ); +// verify(mockWriterConn, times(1)).clearWarnings(); +// verify(mockReaderConn1, times(1)).clearWarnings(); +// } +// +// @Test +// public void testExecuteClearWarningsOnClosedConnectionsIsNotCalled() throws SQLException { +// when(mockWriterConn.isClosed()).thenReturn(true); +// when(mockReaderConn1.isClosed()).thenReturn(true); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// mockReaderConn1); +// +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// mockStatement, +// "Connection.clearWarnings", +// mockSqlFunction, +// new Object[] {} +// ); +// verify(mockWriterConn, never()).clearWarnings(); +// verify(mockReaderConn1, never()).clearWarnings(); +// } +// +// @Test +// public void testExecuteClearWarningsOnNullConnectionsIsNotCalled() throws SQLException { +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// +// // calling clearWarnings() on nullified connection would throw an exception +// assertDoesNotThrow(() -> { +// plugin.execute( +// ResultSet.class, +// SQLException.class, +// mockStatement, +// "Connection.clearWarnings", +// mockSqlFunction, +// new Object[] {} +// ); +// }); +// } +// +// @Test +// public void testClosePooledReaderConnectionAfterSetReadOnly() throws SQLException { +// doReturn(writerHostSpec) +// .doReturn(writerHostSpec) +// .doReturn(readerHostSpec1) +// .when(this.mockPluginService).getCurrentHostSpec(); +// doReturn(mockReaderConn1).when(mockPluginService).connect(readerHostSpec1, null); +// when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); +// when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// mockWriterConn, +// null); +// final ReadWriteSplittingPlugin spyPlugin = spy(plugin); +// +// spyPlugin.switchConnectionIfRequired(true); +// spyPlugin.switchConnectionIfRequired(false); +// +// verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockReaderConn1)); +// } +// +// @Test +// public void testClosePooledWriterConnectionAfterSetReadOnly() throws SQLException { +// doReturn(writerHostSpec) +// .doReturn(writerHostSpec) +// .doReturn(readerHostSpec1) +// .doReturn(readerHostSpec1) +// .doReturn(writerHostSpec) +// .when(this.mockPluginService).getCurrentHostSpec(); +// doReturn(mockWriterConn).when(mockPluginService).connect(writerHostSpec, null); +// when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); +// when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); +// +// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( +// mockPluginService, +// defaultProps, +// mockHostListProviderService, +// null, +// null); +// final ReadWriteSplittingPlugin spyPlugin = spy(plugin); +// +// spyPlugin.switchConnectionIfRequired(true); +// spyPlugin.switchConnectionIfRequired(false); +// spyPlugin.switchConnectionIfRequired(true); +// +// verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockWriterConn)); +// } +// +// private static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props) { +// final HikariConfig config = new HikariConfig(); +// config.setMaximumPoolSize(3); +// config.setInitializationFailTimeout(75000); +// config.setConnectionTimeout(10000); +// return config; +// } +// +// private static String getPoolKey(HostSpec hostSpec, Properties props) { +// final String user = props.getProperty(PropertyDefinition.USER.name); +// final String somePropertyValue = props.getProperty("somePropertyValue"); +// return hostSpec.getUrl() + user + somePropertyValue; +// } +// } diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 450b494a7..a3a661c56 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -1,315 +1,315 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package software.amazon.jdbc.util.monitoring; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.spy; - -import java.sql.SQLException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Properties; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.events.EventPublisher; -import software.amazon.jdbc.util.storage.StorageService; -import software.amazon.jdbc.util.telemetry.TelemetryFactory; - -class MonitorServiceImplTest { - @Mock FullServicesContainer mockServicesContainer; - @Mock StorageService mockStorageService; - @Mock ConnectionProvider mockConnectionProvider; - @Mock TelemetryFactory mockTelemetryFactory; - @Mock TargetDriverDialect mockTargetDriverDialect; - @Mock Dialect mockDbDialect; - @Mock EventPublisher mockPublisher; - String url = "jdbc:postgresql://somehost/somedb"; - String protocol = "someProtocol"; - Properties props = new Properties(); - MonitorServiceImpl spyMonitorService; - private AutoCloseable closeable; - - @BeforeEach - void setUp() throws SQLException { - closeable = MockitoAnnotations.openMocks(this); - spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); - doNothing().when(spyMonitorService).initCleanupThread(anyInt()); - doReturn(mockServicesContainer).when(spyMonitorService).getNewServicesContainer( - eq(mockStorageService), - eq(mockConnectionProvider), - eq(mockTelemetryFactory), - eq(url), - eq(protocol), - eq(mockTargetDriverDialect), - eq(mockDbDialect), - eq(props)); - } - - @AfterEach - void tearDown() throws Exception { - closeable.close(); - spyMonitorService.releaseResources(); - } - - @Test - public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, - (mockServicesContainer) -> new NoOpMonitor(30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - monitor.state.set(MonitorState.ERROR); - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(newMonitor); - assertNotEquals(monitor, newMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, newMonitor.getState()); - } - - @Test - public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - 1, // heartbeat times out immediately - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, - (mockServicesContainer) -> new NoOpMonitor(30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(newMonitor); - assertNotEquals(monitor, newMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, newMonitor.getState()); - } - - @Test - public void testMonitorExpired() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, - (mockServicesContainer) -> new NoOpMonitor(30) - ); - - Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); - assertNotNull(storedMonitor); - assertEquals(monitor, storedMonitor); - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - assertEquals(MonitorState.RUNNING, monitor.getState()); - - // checkMonitors() should detect the expiration timeout and stop/remove the monitor. - spyMonitorService.checkMonitors(); - - assertEquals(MonitorState.STOPPED, monitor.getState()); - - Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); - // monitor should have been removed when checkMonitors() was called. - assertNull(newMonitor); - } - - @Test - public void testMonitorMismatch() { - assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( - CustomEndpointMonitorImpl.class, - "testMonitor", - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, - // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor - // service should detect this and throw an exception. - (mockServicesContainer) -> new NoOpMonitor(30) - )); - } - - @Test - public void testRemove() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, - (mockServicesContainer) -> new NoOpMonitor(30) - ); - assertNotNull(monitor); - - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); - assertEquals(monitor, removedMonitor); - assertEquals(MonitorState.RUNNING, monitor.getState()); - } - - @Test - public void testStopAndRemove() throws SQLException, InterruptedException { - spyMonitorService.registerMonitorTypeIfAbsent( - NoOpMonitor.class, - TimeUnit.MINUTES.toNanos(1), - TimeUnit.MINUTES.toNanos(1), - // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this - // indicates it is not being used. - new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), - null - ); - - String key = "testMonitor"; - NoOpMonitor monitor = spyMonitorService.runIfAbsent( - NoOpMonitor.class, - key, - mockStorageService, - mockTelemetryFactory, - mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, - (mockServicesContainer) -> new NoOpMonitor(30) - ); - assertNotNull(monitor); - - // need to wait to give time for the monitor executor to start the monitor thread. - TimeUnit.MILLISECONDS.sleep(250); - spyMonitorService.stopAndRemove(NoOpMonitor.class, key); - assertNull(spyMonitorService.get(NoOpMonitor.class, key)); - assertEquals(MonitorState.STOPPED, monitor.getState()); - } - - static class NoOpMonitor extends AbstractMonitor { - protected NoOpMonitor(long terminationTimeoutSec) { - super(terminationTimeoutSec); - } - - @Override - public void monitor() { - // do nothing. - } - } -} +// /* +// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"). +// * You may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package software.amazon.jdbc.util.monitoring; +// +// import static org.junit.jupiter.api.Assertions.assertEquals; +// import static org.junit.jupiter.api.Assertions.assertNotEquals; +// import static org.junit.jupiter.api.Assertions.assertNotNull; +// import static org.junit.jupiter.api.Assertions.assertNull; +// import static org.junit.jupiter.api.Assertions.assertThrows; +// import static org.mockito.ArgumentMatchers.anyInt; +// import static org.mockito.ArgumentMatchers.eq; +// import static org.mockito.Mockito.doNothing; +// import static org.mockito.Mockito.doReturn; +// import static org.mockito.Mockito.spy; +// +// import java.sql.SQLException; +// import java.util.Collections; +// import java.util.HashSet; +// import java.util.Properties; +// import java.util.concurrent.TimeUnit; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.mockito.Mock; +// import org.mockito.MockitoAnnotations; +// import software.amazon.jdbc.ConnectionProvider; +// import software.amazon.jdbc.dialect.Dialect; +// import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; +// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +// import software.amazon.jdbc.util.FullServicesContainer; +// import software.amazon.jdbc.util.events.EventPublisher; +// import software.amazon.jdbc.util.storage.StorageService; +// import software.amazon.jdbc.util.telemetry.TelemetryFactory; +// +// class MonitorServiceImplTest { +// @Mock FullServicesContainer mockServicesContainer; +// @Mock StorageService mockStorageService; +// @Mock ConnectionProvider mockConnectionProvider; +// @Mock TelemetryFactory mockTelemetryFactory; +// @Mock TargetDriverDialect mockTargetDriverDialect; +// @Mock Dialect mockDbDialect; +// @Mock EventPublisher mockPublisher; +// String url = "jdbc:postgresql://somehost/somedb"; +// String protocol = "someProtocol"; +// Properties props = new Properties(); +// MonitorServiceImpl spyMonitorService; +// private AutoCloseable closeable; +// +// @BeforeEach +// void setUp() throws SQLException { +// closeable = MockitoAnnotations.openMocks(this); +// spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); +// doNothing().when(spyMonitorService).initCleanupThread(anyInt()); +// doReturn(mockServicesContainer).when(spyMonitorService).getNewServicesContainer( +// eq(mockStorageService), +// eq(mockConnectionProvider), +// eq(mockTelemetryFactory), +// eq(url), +// eq(protocol), +// eq(mockTargetDriverDialect), +// eq(mockDbDialect), +// eq(props)); +// } +// +// @AfterEach +// void tearDown() throws Exception { +// closeable.close(); +// spyMonitorService.releaseResources(); +// } +// +// @Test +// public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// url, +// protocol, +// mockTargetDriverDialect, +// mockDbDialect, +// props, +// (mockServicesContainer) -> new NoOpMonitor(30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// monitor.state.set(MonitorState.ERROR); +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(newMonitor); +// assertNotEquals(monitor, newMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, newMonitor.getState()); +// } +// +// @Test +// public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// 1, // heartbeat times out immediately +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// url, +// protocol, +// mockTargetDriverDialect, +// mockDbDialect, +// props, +// (mockServicesContainer) -> new NoOpMonitor(30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(newMonitor); +// assertNotEquals(monitor, newMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, newMonitor.getState()); +// } +// +// @Test +// public void testMonitorExpired() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// url, +// protocol, +// mockTargetDriverDialect, +// mockDbDialect, +// props, +// (mockServicesContainer) -> new NoOpMonitor(30) +// ); +// +// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// assertNotNull(storedMonitor); +// assertEquals(monitor, storedMonitor); +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// +// // checkMonitors() should detect the expiration timeout and stop/remove the monitor. +// spyMonitorService.checkMonitors(); +// +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// +// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); +// // monitor should have been removed when checkMonitors() was called. +// assertNull(newMonitor); +// } +// +// @Test +// public void testMonitorMismatch() { +// assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( +// CustomEndpointMonitorImpl.class, +// "testMonitor", +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// url, +// protocol, +// mockTargetDriverDialect, +// mockDbDialect, +// props, +// // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor +// // service should detect this and throw an exception. +// (mockServicesContainer) -> new NoOpMonitor(30) +// )); +// } +// +// @Test +// public void testRemove() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// url, +// protocol, +// mockTargetDriverDialect, +// mockDbDialect, +// props, +// (mockServicesContainer) -> new NoOpMonitor(30) +// ); +// assertNotNull(monitor); +// +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); +// assertEquals(monitor, removedMonitor); +// assertEquals(MonitorState.RUNNING, monitor.getState()); +// } +// +// @Test +// public void testStopAndRemove() throws SQLException, InterruptedException { +// spyMonitorService.registerMonitorTypeIfAbsent( +// NoOpMonitor.class, +// TimeUnit.MINUTES.toNanos(1), +// TimeUnit.MINUTES.toNanos(1), +// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this +// // indicates it is not being used. +// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), +// null +// ); +// +// String key = "testMonitor"; +// NoOpMonitor monitor = spyMonitorService.runIfAbsent( +// NoOpMonitor.class, +// key, +// mockStorageService, +// mockTelemetryFactory, +// mockConnectionProvider, +// url, +// protocol, +// mockTargetDriverDialect, +// mockDbDialect, +// props, +// (mockServicesContainer) -> new NoOpMonitor(30) +// ); +// assertNotNull(monitor); +// +// // need to wait to give time for the monitor executor to start the monitor thread. +// TimeUnit.MILLISECONDS.sleep(250); +// spyMonitorService.stopAndRemove(NoOpMonitor.class, key); +// assertNull(spyMonitorService.get(NoOpMonitor.class, key)); +// assertEquals(MonitorState.STOPPED, monitor.getState()); +// } +// +// static class NoOpMonitor extends AbstractMonitor { +// protected NoOpMonitor(long terminationTimeoutSec) { +// super(terminationTimeoutSec); +// } +// +// @Override +// public void monitor() { +// // do nothing. +// } +// } +// } From 060dab68c3bf57e3f91f194ef35eb18fffa5576f Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 17 Sep 2025 17:50:51 -0700 Subject: [PATCH 44/54] Change ConnectionContext#getPropsCopy to getProps --- .../amazon/jdbc/C3P0PooledConnectionProvider.java | 9 +++++---- .../amazon/jdbc/DataSourceConnectionProvider.java | 9 +++++---- .../software/amazon/jdbc/DriverConnectionProvider.java | 10 +++++----- .../amazon/jdbc/HikariPooledConnectionProvider.java | 5 ++--- .../software/amazon/jdbc/PartialPluginService.java | 4 ++-- .../java/software/amazon/jdbc/PluginServiceImpl.java | 8 ++++---- .../software/amazon/jdbc/dialect/DialectManager.java | 2 +- .../ConnectionStringHostListProvider.java | 2 +- .../plugin/AuroraInitialConnectionStrategyPlugin.java | 2 +- .../jdbc/plugin/AwsSecretsManagerConnectionPlugin.java | 4 ++-- .../jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java | 2 +- .../plugin/failover2/FailoverConnectionPlugin.java | 2 +- .../jdbc/plugin/federatedauth/FederatedAuthPlugin.java | 4 ++-- .../jdbc/plugin/federatedauth/OktaAuthPlugin.java | 4 ++-- .../jdbc/plugin/iam/IamAuthConnectionPlugin.java | 2 +- .../plugin/limitless/LimitlessConnectionPlugin.java | 2 +- .../jdbc/plugin/staledns/AuroraStaleDnsHelper.java | 2 +- .../java/software/amazon/jdbc/util/ServiceUtility.java | 2 +- .../amazon/jdbc/util/connection/ConnectionContext.java | 5 ++--- .../jdbc/util/connection/ConnectionServiceImpl.java | 2 +- 20 files changed, 41 insertions(+), 41 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java index 73b147b90..eb7971b2d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java @@ -33,6 +33,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.ConnectInfo; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.connection.ConnectionContext; import software.amazon.jdbc.util.storage.SlidingExpirationCache; @@ -80,16 +81,16 @@ public HostSpec getHostSpecByStrategy(@NonNull List hosts, @NonNull Ho public Connection connect( @NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException { Dialect dialect = connectionContext.getDbDialect(); - Properties props = connectionContext.getPropsCopy(); - dialect.prepareConnectProperties(props, connectionContext.getProtocol(), hostSpec); + Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); + dialect.prepareConnectProperties(propsCopy, connectionContext.getProtocol(), hostSpec); final ComboPooledDataSource ds = databasePools.computeIfAbsent( hostSpec.getUrl(), - (key) -> createDataSource(connectionContext, hostSpec, props), + (key) -> createDataSource(connectionContext, hostSpec, propsCopy), poolExpirationCheckNanos ); - ds.setPassword(props.getProperty(PropertyDefinition.PASSWORD.name)); + ds.setPassword(propsCopy.getProperty(PropertyDefinition.PASSWORD.name)); return ds.getConnection(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java index e2a2ed6c5..353e0847d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java @@ -32,6 +32,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.exceptions.SQLLoginException; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.WrapperUtils; @@ -100,8 +101,8 @@ public HostSpec getHostSpecByStrategy( @Override public Connection connect( final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec) throws SQLException { - final Properties copy = connectionContext.getPropsCopy(); - connectionContext.getDbDialect().prepareConnectProperties(copy, connectionContext.getProtocol(), hostSpec); + final Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); + connectionContext.getDbDialect().prepareConnectProperties(propsCopy, connectionContext.getProtocol(), hostSpec); Connection conn; @@ -110,7 +111,7 @@ public Connection connect( LOGGER.finest(() -> "Use a separate DataSource object to create a connection."); // use a new data source instance to instantiate a connection final DataSource ds = createDataSource(); - conn = this.openConnection(ds, connectionContext, hostSpec, copy); + conn = this.openConnection(ds, connectionContext, hostSpec, propsCopy); } else { @@ -119,7 +120,7 @@ public Connection connect( this.lock.lock(); LOGGER.finest(() -> "Use main DataSource object to create a connection."); try { - conn = this.openConnection(this.dataSource, connectionContext, hostSpec, copy); + conn = this.openConnection(this.dataSource, connectionContext, hostSpec, propsCopy); } finally { this.lock.unlock(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java index 62b27359a..403ff279e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java @@ -97,11 +97,11 @@ public HostSpec getHostSpecByStrategy( @Override public Connection connect(final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec) throws SQLException { - final Properties copy = connectionContext.getPropsCopy(); + final Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); final ConnectInfo connectInfo = - connectionContext.getDriverDialect().prepareConnectInfo(connectionContext.getProtocol(), hostSpec, copy); + connectionContext.getDriverDialect().prepareConnectInfo(connectionContext.getProtocol(), hostSpec, propsCopy); - connectionContext.getDbDialect().prepareConnectProperties(copy, connectionContext.getProtocol(), hostSpec); + connectionContext.getDbDialect().prepareConnectProperties(propsCopy, connectionContext.getProtocol(), hostSpec); LOGGER.finest(() -> "Connecting to " + connectInfo.url + PropertyUtils.logProperties( PropertyUtils.maskProperties(connectInfo.props), @@ -113,7 +113,7 @@ public Connection connect(final @NonNull ConnectionContext connectionContext, fi } catch (Throwable throwable) { - if (!PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(copy)) { + if (!PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(propsCopy)) { throw throwable; } @@ -159,7 +159,7 @@ public Connection connect(final @NonNull ConnectionContext connectionContext, fi .build(); final ConnectInfo fixedConnectInfo = connectionContext.getDriverDialect().prepareConnectInfo( - connectionContext.getProtocol(), connectionHostSpec, copy); + connectionContext.getProtocol(), connectionHostSpec, propsCopy); LOGGER.finest(() -> "Connecting to " + fixedConnectInfo.url + " after correcting the hostname from " + originalHost diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java index 1f68ce0cf..fe74a6a5e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java @@ -205,7 +205,7 @@ public HikariPooledConnectionProvider( @Override public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { if (this.acceptsUrlFunc != null) { - return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionContext.getPropsCopy()); + return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionContext.getProps()); } final RdsUrlType urlType = rdsUtils.identifyRdsType(hostSpec.getHost()); @@ -240,7 +240,7 @@ public HostSpec getHostSpecByStrategy( @Override public Connection connect(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = connectionContext.getPropsCopy(); + final Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); HostSpec connectionHostSpec = hostSpec; if (PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(propsCopy) @@ -315,7 +315,6 @@ protected void configurePool( final ConnectionContext connectionContext, final HostSpec hostSpec, final Properties connectionProps) { - final Properties copy = PropertyUtils.copyProperties(connectionProps); ConnectInfo connectInfo; diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 74fe058cf..94445700e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -624,12 +624,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getPropsCopy())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getProps())); } @Override public Properties getProperties() { - return this.connectionContext.getPropsCopy(); + return this.connectionContext.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 007b7043b..afbb9f853 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -125,7 +125,7 @@ public PluginServiceImpl( this.sessionStateService = sessionStateService != null ? sessionStateService - : new SessionStateServiceImpl(this, this.connectionContext.getPropsCopy()); + : new SessionStateServiceImpl(this, this.connectionContext.getProps()); this.exceptionHandler = this.configurationProfile != null && this.configurationProfile.getExceptionHandler() != null ? this.configurationProfile.getExceptionHandler() @@ -291,7 +291,7 @@ public EnumSet setCurrentConnection( this.setInTransaction(false); if (isInTransaction - && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionContext.getPropsCopy())) { + && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionContext.getProps())) { try { oldConnection.rollback(); } catch (final SQLException e) { @@ -747,12 +747,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getPropsCopy())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getProps())); } @Override public Properties getProperties() { - return this.connectionContext.getPropsCopy(); + return this.connectionContext.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index 0eecadbfa..aa24188fe 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -131,7 +131,7 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th return this.dialect; } - final String userDialectSetting = DIALECT.getString(connectionContext.getPropsCopy()); + final String userDialectSetting = DIALECT.getString(connectionContext.getProps()); final String dialectCode = !StringUtils.isNullOrEmpty(userDialectSetting) ? userDialectSetting : knownEndpointDialects.get(connectionContext.getUrl()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index 5a4d4fc9c..3c985b96f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -60,7 +60,7 @@ public ConnectionStringHostListProvider( final @NonNull HostListProviderService hostListProviderService, final @NonNull ConnectionUrlParser connectionUrlParser) { this.connectionContext = connectionContext; - this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionContext.getPropsCopy()); + this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionContext.getProps()); this.connectionUrlParser = connectionUrlParser; this.hostListProviderService = hostListProviderService; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 6d8e6d558..610ae776a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -139,7 +139,7 @@ public Connection connect( final JdbcCallable connectFunc) throws SQLException { final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost()); - final Properties props = connectionContext.getPropsCopy(); + final Properties props = connectionContext.getProps(); if (type == RdsUrlType.RDS_WRITER_CLUSTER || isInitialConnection && this.verifyOpenedConnectionType == VerifyOpenedConnectionType.WRITER) { Connection writerCandidateConn = this.getVerifiedWriterConnection(props, isInitialConnection, connectFunc); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java index 14960187f..b4a815ef6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java @@ -190,7 +190,7 @@ public Connection connect( final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getPropsCopy(), connectFunc); + return connectInternal(hostSpec, connectionContext.getProps(), connectFunc); } private Connection connectInternal(HostSpec hostSpec, Properties props, @@ -231,7 +231,7 @@ public Connection forceConnect( final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getPropsCopy(), forceConnectFunc); + return connectInternal(hostSpec, connectionContext.getProps(), forceConnectFunc); } /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java index 0fc3310a8..fea69250b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java @@ -600,7 +600,7 @@ protected void initHostListProvider() { } final ConnectionContext originalContext = this.pluginService.getConnectionContext(); - final Properties hostListProperties = originalContext.getPropsCopy(); + final Properties hostListProperties = originalContext.getProps(); // Need to instantiate a separate HostListProvider with // a special unique clusterId to avoid interference with other HostListProviders opened for this cluster. diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java index 7eb64085b..de45eeef4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java @@ -734,7 +734,7 @@ public Connection connect( this.initFailoverMode(); Connection conn = null; - Properties props = connectionContext.getPropsCopy(); + Properties props = connectionContext.getProps(); if (!ENABLE_CONNECT_FAILOVER.getBoolean(props)) { return this.staleDnsHelper.getVerifiedConnection( isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index f40e4be72..d1bbab792 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -157,7 +157,7 @@ public Connection connect( final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getPropsCopy(), connectFunc); + return connectInternal(hostSpec, connectionContext.getProps(), connectFunc); } @Override @@ -167,7 +167,7 @@ public Connection forceConnect( final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getPropsCopy(), forceConnectFunc); + return connectInternal(hostSpec, connectionContext.getProps(), forceConnectFunc); } private Connection connectInternal( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index b09b0d9e1..510c508cf 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -138,7 +138,7 @@ public Connection connect( final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getPropsCopy(), connectFunc); + return connectInternal(hostSpec, connectionContext.getProps(), connectFunc); } @Override @@ -148,7 +148,7 @@ public Connection forceConnect( boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getPropsCopy(), forceConnectFunc); + return connectInternal(hostSpec, connectionContext.getProps(), forceConnectFunc); } private Connection connectInternal(final HostSpec hostSpec, final Properties props, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index 64ac50ba7..eec9d7fc9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -118,7 +118,7 @@ private Connection connectInternal( ConnectionContext connectionContext, HostSpec hostSpec, JdbcCallable connectFunc) throws SQLException { - Properties props = connectionContext.getPropsCopy(); + Properties props = connectionContext.getProps(); if (StringUtils.isNullOrEmpty(PropertyDefinition.USER.getString(props))) { throw new SQLException(PropertyDefinition.USER.name + " is null or empty."); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java index be21a576d..8284dc086 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java @@ -130,7 +130,7 @@ public Connection connect( final LimitlessConnectionContext context = new LimitlessConnectionContext( hostSpec, - connectionContext.getPropsCopy(), + connectionContext.getProps(), conn, connectFunc, null, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java index 4043edea3..65e31b694 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java @@ -147,7 +147,7 @@ public Connection getVerifiedConnection( ); } - final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectionContext.getPropsCopy()); + final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectionContext.getProps()); if (isInitialConnection) { hostListProviderService.setInitialConnectionHostSpec(this.writerHostSpec); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index 6a5c4a191..f194671af 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -71,7 +71,7 @@ public FullServicesContainer createServiceContainer( servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); - pluginManager.init(servicesContainer, connectionContext.getPropsCopy(), partialPluginService, null); + pluginManager.init(servicesContainer, connectionContext.getProps(), partialPluginService, null); return servicesContainer; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java index 9223a3853..8110fccb3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java @@ -20,7 +20,6 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.ConnectionUrlParser; -import software.amazon.jdbc.util.PropertyUtils; public class ConnectionContext { protected static final ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); @@ -53,8 +52,8 @@ public TargetDriverDialect getDriverDialect() { return driverDialect; } - public Properties getPropsCopy() { - return PropertyUtils.copyProperties(props); + public Properties getProps() { + return props; } public Dialect getDbDialect() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index e9232cbad..306bf7303 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -72,7 +72,7 @@ public ConnectionServiceImpl( servicesContainer.setPluginManagerService(partialPluginService); this.pluginService = partialPluginService; - this.pluginManager.init(servicesContainer, this.connectionContext.getPropsCopy(), partialPluginService, null); + this.pluginManager.init(servicesContainer, this.connectionContext.getProps(), partialPluginService, null); } @Override From f0410f981e68e078eb3d2742198f605440a92451 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 19 Sep 2025 13:43:25 -0700 Subject: [PATCH 45/54] Fix test failures --- .../java/software/amazon/jdbc/PluginServiceImpl.java | 6 +++--- .../jdbc/hostlistprovider/RdsHostListProvider.java | 9 ++++----- .../monitoring/MonitoringRdsHostListProvider.java | 4 ++-- .../monitoring/MonitoringRdsMultiAzHostListProvider.java | 2 +- .../java/software/amazon/jdbc/util/ServiceUtility.java | 4 +++- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index afbb9f853..2bfaf6b98 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -696,13 +696,13 @@ public TargetDriverDialect getTargetDriverDialect() { public void updateDialect(final @NonNull Connection connection) throws SQLException { final Dialect originalDialect = this.connectionContext.getDbDialect(); - Dialect dialect = this.dialectProvider.getDialect( + Dialect currentDialect = this.dialectProvider.getDialect( this.connectionContext.getProtocol(), this.initialConnectionHostSpec, connection); - if (originalDialect == this.connectionContext.getDbDialect()) { + if (originalDialect == currentDialect) { return; } - this.connectionContext.setDbDialect(dialect); + this.connectionContext.setDbDialect(currentDialect); final HostListProviderSupplier supplier = this.connectionContext.getDbDialect().getHostListProvider(); this.setHostListProvider(supplier.getProvider(this.connectionContext, this.servicesContainer)); this.refreshHostList(connection); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java index 0883ff909..cb7242423 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java @@ -117,8 +117,6 @@ public class RdsHostListProvider implements DynamicHostListProvider { protected volatile boolean isInitialized = false; - protected Properties properties; - static { PropertyDefinition.registerPluginProperties(RdsHostListProvider.class); } @@ -161,11 +159,12 @@ protected void init() throws SQLException { this.clusterId = UUID.randomUUID().toString(); this.isPrimaryClusterId = false; + Properties props = this.connectionContext.getProps(); this.refreshRateNano = - TimeUnit.MILLISECONDS.toNanos(CLUSTER_TOPOLOGY_REFRESH_RATE_MS.getInteger(properties)); + TimeUnit.MILLISECONDS.toNanos(CLUSTER_TOPOLOGY_REFRESH_RATE_MS.getInteger(props)); HostSpecBuilder hostSpecBuilder = this.hostListProviderService.getHostSpecBuilder(); - String clusterInstancePattern = CLUSTER_INSTANCE_HOST_PATTERN.getString(this.properties); + String clusterInstancePattern = CLUSTER_INSTANCE_HOST_PATTERN.getString(props); if (clusterInstancePattern != null) { this.clusterInstanceTemplate = ConnectionUrlParser.parseHostPortPair(clusterInstancePattern, () -> hostSpecBuilder); @@ -182,7 +181,7 @@ protected void init() throws SQLException { this.rdsUrlType = rdsHelper.identifyRdsType(this.initialHostSpec.getHost()); - final String clusterIdSetting = CLUSTER_ID.getString(this.properties); + final String clusterIdSetting = CLUSTER_ID.getString(props); if (!StringUtils.isNullOrEmpty(clusterIdSetting)) { this.clusterId = clusterIdSetting; } else if (rdsUrlType == RdsUrlType.RDS_PROXY) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java index 97d2f7f0b..6eb96d5e6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java @@ -67,7 +67,7 @@ public MonitoringRdsHostListProvider( this.pluginService = servicesContainer.getPluginService(); this.writerTopologyQuery = writerTopologyQuery; this.highRefreshRateNano = TimeUnit.MILLISECONDS.toNanos( - CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.getLong(this.properties)); + CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.getLong(this.connectionContext.getProps())); } public static void clearCache() { @@ -91,7 +91,7 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.servicesContainer, this.clusterId, this.initialHostSpec, - this.properties, + this.connectionContext.getProps(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java index 6df75253a..c8626e344 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java @@ -60,7 +60,7 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { servicesContainer, this.clusterId, this.initialHostSpec, - this.properties, + this.connectionContext.getProps(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index f194671af..a76432c2a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -17,6 +17,7 @@ package software.amazon.jdbc.util; import java.sql.SQLException; +import java.util.Properties; import java.util.concurrent.locks.ReentrantLock; import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; @@ -71,7 +72,8 @@ public FullServicesContainer createServiceContainer( servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); - pluginManager.init(servicesContainer, connectionContext.getProps(), partialPluginService, null); + Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); + pluginManager.init(servicesContainer, propsCopy, partialPluginService, null); return servicesContainer; } } From a612b0328bb1dc5ad565abcaaba8dfdbec0d130a Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 19 Sep 2025 15:09:14 -0700 Subject: [PATCH 46/54] ConnectionContext#getUrl -> ConnectionContext#getInitialConnectionString --- .../testplugin/BenchmarkPlugin.java | 2 +- .../amazon/jdbc/PartialPluginService.java | 2 +- .../amazon/jdbc/PluginServiceImpl.java | 2 +- .../amazon/jdbc/dialect/DialectManager.java | 12 +++++----- .../ConnectionStringHostListProvider.java | 4 ++-- .../hostlistprovider/RdsHostListProvider.java | 4 ++-- .../util/connection/ConnectionContext.java | 23 ++++++++++--------- 7 files changed, 25 insertions(+), 24 deletions(-) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java index 2979a6555..6e94a7b7a 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java @@ -100,7 +100,7 @@ public void initHostProvider( ConnectionContext connectionContext, HostListProviderService hostListProviderService, JdbcCallable initHostProviderFunc) { - LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectionContext.getUrl())); + LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectionContext.getInitialConnectionString())); resources.add("initHostProvider"); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 94445700e..6b8cc318b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -169,7 +169,7 @@ public ConnectionContext getConnectionContext() { @Override public String getOriginalUrl() { - return this.connectionContext.getUrl(); + return this.connectionContext.getInitialConnectionString(); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 2bfaf6b98..82772eefb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -191,7 +191,7 @@ public ConnectionContext getConnectionContext() { @Override public String getOriginalUrl() { - return this.connectionContext.getUrl(); + return this.connectionContext.getInitialConnectionString(); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index aa24188fe..771ed8c1d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -134,7 +134,7 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th final String userDialectSetting = DIALECT.getString(connectionContext.getProps()); final String dialectCode = !StringUtils.isNullOrEmpty(userDialectSetting) ? userDialectSetting - : knownEndpointDialects.get(connectionContext.getUrl()); + : knownEndpointDialects.get(connectionContext.getInitialConnectionString()); if (!StringUtils.isNullOrEmpty(dialectCode)) { final Dialect userDialect = knownDialectsByCode.get(dialectCode); @@ -153,15 +153,15 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th throw new IllegalArgumentException("protocol"); } - String host = connectionContext.getUrl(); + String connectionString = connectionContext.getInitialConnectionString(); final List hosts = this.connectionUrlParser.getHostsFromConnectionUrl( - connectionContext.getUrl(), true, pluginService::getHostSpecBuilder); + connectionContext.getInitialConnectionString(), true, pluginService::getHostSpecBuilder); if (!Utils.isNullOrEmpty(hosts)) { - host = hosts.get(0).getHost(); + connectionString = hosts.get(0).getHost(); } if (connectionContext.getProtocol().contains("mysql")) { - RdsUrlType type = this.rdsHelper.identifyRdsType(host); + RdsUrlType type = this.rdsHelper.identifyRdsType(connectionString); if (type.isRdsCluster()) { this.canUpdate = true; this.dialectCode = DialectCodes.AURORA_MYSQL; @@ -183,7 +183,7 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th } if (connectionContext.getProtocol().contains("postgresql")) { - RdsUrlType type = this.rdsHelper.identifyRdsType(host); + RdsUrlType type = this.rdsHelper.identifyRdsType(connectionString); if (RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP.equals(type)) { this.canUpdate = false; this.dialectCode = DialectCodes.AURORA_PG; diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index 3c985b96f..4311917f2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -71,12 +71,12 @@ private void init() throws SQLException { } this.hostList.addAll( this.connectionUrlParser.getHostsFromConnectionUrl( - this.connectionContext.getUrl(), + this.connectionContext.getInitialConnectionString(), this.isSingleWriterConnectionString, () -> this.hostListProviderService.getHostSpecBuilder())); if (this.hostList.isEmpty()) { throw new SQLException(Messages.get("ConnectionStringHostListProvider.parsedListEmpty", - new Object[] {this.connectionContext.getUrl()})); + new Object[] {this.connectionContext.getInitialConnectionString()})); } this.hostListProviderService.setInitialConnectionHostSpec(this.hostList.get(0)); this.isInitialized = true; diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java index cb7242423..6e043631f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java @@ -148,11 +148,11 @@ protected void init() throws SQLException { // initial topology is based on connection string this.initialHostList = - connectionUrlParser.getHostsFromConnectionUrl(this.connectionContext.getUrl(), false, + connectionUrlParser.getHostsFromConnectionUrl(this.connectionContext.getInitialConnectionString(), false, this.hostListProviderService::getHostSpecBuilder); if (this.initialHostList == null || this.initialHostList.isEmpty()) { throw new SQLException(Messages.get("RdsHostListProvider.parsedListEmpty", - new Object[] {this.connectionContext.getUrl()})); + new Object[] {this.connectionContext.getInitialConnectionString()})); } this.initialHostSpec = this.initialHostList.get(0); this.hostListProviderService.setInitialConnectionHostSpec(this.initialHostSpec); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java index 8110fccb3..a06b25802 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java @@ -23,41 +23,42 @@ public class ConnectionContext { protected static final ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); - protected final String url; + protected final String initialConnectionString; protected final String protocol; protected final TargetDriverDialect driverDialect; protected final Properties props; protected Dialect dbDialect; - public ConnectionContext(String url, TargetDriverDialect driverDialect, Properties props) { - this(url, connectionUrlParser.getProtocol(url), driverDialect, props); + public ConnectionContext(String initialConnectionString, TargetDriverDialect driverDialect, Properties props) { + this(initialConnectionString, connectionUrlParser.getProtocol(initialConnectionString), driverDialect, props); } - public ConnectionContext(String url, String protocol, TargetDriverDialect driverDialect, Properties props) { - this.url = url; + public ConnectionContext( + String initialConnectionString, String protocol, TargetDriverDialect driverDialect, Properties props) { + this.initialConnectionString = initialConnectionString; this.protocol = protocol; this.driverDialect = driverDialect; this.props = props; } - public String getUrl() { - return url; + public String getInitialConnectionString() { + return this.initialConnectionString; } public String getProtocol() { - return protocol; + return this.protocol; } public TargetDriverDialect getDriverDialect() { - return driverDialect; + return this.driverDialect; } public Properties getProps() { - return props; + return this.props; } public Dialect getDbDialect() { - return dbDialect; + return this.dbDialect; } public void setDbDialect(Dialect dbDialect) { From 7c253da275ed790ad248596a45c00475fd495776 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 19 Sep 2025 15:16:57 -0700 Subject: [PATCH 47/54] ConnectionContext -> ConnectionInfo --- .../testplugin/BenchmarkPlugin.java | 10 +-- docs/development-guide/LoadablePlugins.md | 4 +- .../jdbc/C3P0PooledConnectionProvider.java | 20 +++--- .../amazon/jdbc/ConnectionPlugin.java | 14 ++-- .../amazon/jdbc/ConnectionPluginManager.java | 20 +++--- .../amazon/jdbc/ConnectionProvider.java | 10 +-- .../jdbc/ConnectionProviderManager.java | 16 ++--- .../jdbc/DataSourceConnectionProvider.java | 26 +++---- .../amazon/jdbc/DriverConnectionProvider.java | 18 ++--- .../jdbc/HikariPooledConnectionProvider.java | 28 ++++---- .../amazon/jdbc/PartialPluginService.java | 48 ++++++------- .../software/amazon/jdbc/PluginService.java | 8 +-- .../amazon/jdbc/PluginServiceImpl.java | 68 +++++++++---------- .../jdbc/dialect/AuroraMysqlDialect.java | 6 +- .../amazon/jdbc/dialect/AuroraPgDialect.java | 6 +- .../amazon/jdbc/dialect/DialectManager.java | 20 +++--- .../amazon/jdbc/dialect/DialectProvider.java | 5 +- .../dialect/HostListProviderSupplier.java | 4 +- .../amazon/jdbc/dialect/MariaDbDialect.java | 4 +- .../amazon/jdbc/dialect/MysqlDialect.java | 4 +- .../amazon/jdbc/dialect/PgDialect.java | 4 +- .../RdsMultiAzDbClusterMysqlDialect.java | 6 +- .../dialect/RdsMultiAzDbClusterPgDialect.java | 6 +- .../amazon/jdbc/dialect/UnknownDialect.java | 4 +- .../AuroraHostListProvider.java | 6 +- .../ConnectionStringHostListProvider.java | 18 ++--- .../hostlistprovider/RdsHostListProvider.java | 14 ++-- .../RdsMultiAzDbClusterListProvider.java | 6 +- .../ClusterTopologyMonitorImpl.java | 2 +- .../MonitoringRdsHostListProvider.java | 14 ++-- .../MonitoringRdsMultiAzHostListProvider.java | 10 +-- .../jdbc/plugin/AbstractConnectionPlugin.java | 8 +-- .../plugin/AuroraConnectionTrackerPlugin.java | 4 +- ...AuroraInitialConnectionStrategyPlugin.java | 8 +-- .../AwsSecretsManagerConnectionPlugin.java | 10 +-- .../plugin/ConnectTimeConnectionPlugin.java | 6 +- .../jdbc/plugin/DefaultConnectionPlugin.java | 20 +++--- .../bluegreen/BlueGreenConnectionPlugin.java | 4 +- .../bluegreen/BlueGreenStatusMonitor.java | 6 +- .../customendpoint/CustomEndpointPlugin.java | 6 +- .../plugin/dev/DeveloperConnectionPlugin.java | 19 +++--- .../ExceptionSimulatorConnectCallback.java | 4 +- .../efm/HostMonitoringConnectionPlugin.java | 4 +- .../plugin/efm2/HostMonitorServiceImpl.java | 2 +- .../efm2/HostMonitoringConnectionPlugin.java | 4 +- .../ClusterAwareReaderFailoverHandler.java | 2 +- .../ClusterAwareWriterFailoverHandler.java | 2 +- .../failover/FailoverConnectionPlugin.java | 8 +-- .../failover2/FailoverConnectionPlugin.java | 12 ++-- .../federatedauth/FederatedAuthPlugin.java | 10 +-- .../plugin/federatedauth/OktaAuthPlugin.java | 10 +-- .../plugin/iam/IamAuthConnectionPlugin.java | 14 ++-- .../limitless/LimitlessConnectionPlugin.java | 6 +- .../limitless/LimitlessRouterServiceImpl.java | 2 +- .../ReadWriteSplittingPlugin.java | 6 +- .../plugin/staledns/AuroraStaleDnsHelper.java | 6 +- .../plugin/staledns/AuroraStaleDnsPlugin.java | 8 +-- .../FastestResponseStrategyPlugin.java | 4 +- .../HostResponseTimeServiceImpl.java | 2 +- .../amazon/jdbc/util/ServiceUtility.java | 8 +-- ...ectionContext.java => ConnectionInfo.java} | 6 +- .../connection/ConnectionServiceImpl.java | 12 ++-- .../jdbc/util/monitoring/MonitorService.java | 6 +- .../util/monitoring/MonitorServiceImpl.java | 11 ++- .../jdbc/wrapper/ConnectionWrapper.java | 14 ++-- .../aurora/TestPluginServiceImpl.java | 6 +- 66 files changed, 343 insertions(+), 346 deletions(-) rename wrapper/src/main/java/software/amazon/jdbc/util/connection/{ConnectionContext.java => ConnectionInfo.java} (92%) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java index 6e94a7b7a..83227da90 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class BenchmarkPlugin implements ConnectionPlugin, CanReleaseResources { final List resources = new ArrayList<>(); @@ -59,7 +59,7 @@ public T execute(Class resultClass, Class excepti @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -70,7 +70,7 @@ public Connection connect( @Override public Connection forceConnect( - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { LOGGER.finer(() -> String.format("forceConnect=''%s''", hostSpec.getHost())); @@ -97,10 +97,10 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin @Override public void initHostProvider( - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, HostListProviderService hostListProviderService, JdbcCallable initHostProviderFunc) { - LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectionContext.getInitialConnectionString())); + LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectionInfo.getInitialConnectionString())); resources.add("initHostProvider"); } diff --git a/docs/development-guide/LoadablePlugins.md b/docs/development-guide/LoadablePlugins.md index 0b0aec2b7..42c4b97e3 100644 --- a/docs/development-guide/LoadablePlugins.md +++ b/docs/development-guide/LoadablePlugins.md @@ -118,7 +118,7 @@ public class BadPlugin extends AbstractConnectionPlugin { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -172,7 +172,7 @@ public class GoodExample extends AbstractConnectionPlugin { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java index eb7971b2d..9f5f4efbc 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java @@ -34,7 +34,7 @@ import software.amazon.jdbc.targetdriverdialect.ConnectInfo; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class C3P0PooledConnectionProvider implements PooledConnectionProvider, CanReleaseResources { @@ -55,7 +55,7 @@ public class C3P0PooledConnectionProvider implements PooledConnectionProvider, C protected static final long poolExpirationCheckNanos = TimeUnit.MINUTES.toNanos(30); @Override - public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { return true; } @@ -79,14 +79,14 @@ public HostSpec getHostSpecByStrategy(@NonNull List hosts, @NonNull Ho @Override public Connection connect( - @NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException { - Dialect dialect = connectionContext.getDbDialect(); - Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); - dialect.prepareConnectProperties(propsCopy, connectionContext.getProtocol(), hostSpec); + @NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) throws SQLException { + Dialect dialect = connectionInfo.getDbDialect(); + Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); + dialect.prepareConnectProperties(propsCopy, connectionInfo.getProtocol(), hostSpec); final ComboPooledDataSource ds = databasePools.computeIfAbsent( hostSpec.getUrl(), - (key) -> createDataSource(connectionContext, hostSpec, propsCopy), + (key) -> createDataSource(connectionInfo, hostSpec, propsCopy), poolExpirationCheckNanos ); @@ -96,14 +96,14 @@ public Connection connect( } protected ComboPooledDataSource createDataSource( - @NonNull ConnectionContext connectionContext, + @NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec, @NonNull Properties props) { ConnectInfo connectInfo; try { - connectInfo = connectionContext.getDriverDialect() - .prepareConnectInfo(connectionContext.getProtocol(), hostSpec, props); + connectInfo = connectionInfo.getDriverDialect() + .prepareConnectInfo(connectionInfo.getProtocol(), hostSpec, props); } catch (SQLException ex) { throw new RuntimeException(ex); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java index 07f10a783..98ce8bdcb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java @@ -22,7 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * Interface for connection plugins. This class implements ways to execute a JDBC method and to clean up resources used @@ -45,7 +45,7 @@ T execute( * Establishes a connection to the given host using the given driver protocol and properties. If a * non-default {@link ConnectionProvider} has been set with * {@link Driver#setCustomConnectionProvider(ConnectionProvider)} and - * {@link ConnectionProvider#acceptsUrl(ConnectionContext, HostSpec)} returns true for the given + * {@link ConnectionProvider#acceptsUrl(ConnectionInfo, HostSpec)} returns true for the given * protocol, host, and properties, the connection will be created by the non-default * ConnectionProvider. Otherwise, the connection will be created by the default * ConnectionProvider. The default ConnectionProvider will be {@link DriverConnectionProvider} for @@ -53,7 +53,7 @@ T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionContext the connection info for the original connection + * @param connectionInfo the connection info for the original connection * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -65,7 +65,7 @@ T execute( * host */ Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) @@ -80,7 +80,7 @@ Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -92,7 +92,7 @@ Connection connect( * host */ Connection forceConnect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) @@ -132,7 +132,7 @@ HostSpec getHostSpecByStrategy(final List hosts, final HostRole role, throws SQLException, UnsupportedOperationException; void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index af9463d23..8c1b64da7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -52,7 +52,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; @@ -371,7 +371,7 @@ public T execute( * Establishes a connection to the given host using the given driver protocol and properties. If a * non-default {@link ConnectionProvider} has been set with * {@link Driver#setCustomConnectionProvider(ConnectionProvider)} and - * {@link ConnectionProvider#acceptsUrl(ConnectionContext, HostSpec)} returns true for the given + * {@link ConnectionProvider#acceptsUrl(ConnectionInfo, HostSpec)} returns true for the given * protocol, host, and properties, the connection will be created by the non-default * ConnectionProvider. Otherwise, the connection will be created by the default * ConnectionProvider. The default ConnectionProvider will be {@link DriverConnectionProvider} for @@ -379,7 +379,7 @@ public T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionContext the connection info for the original connection + * @param connectionInfo the connection info for the original connection * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -390,7 +390,7 @@ public T execute( * host */ public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final @Nullable ConnectionPlugin pluginToSkip) @@ -403,7 +403,7 @@ public Connection connect( return executeWithSubscribedPlugins( JdbcMethod.CONNECT, (plugin, func) -> - plugin.connect(connectionContext, hostSpec, isInitialConnection, func), + plugin.connect(connectionInfo, hostSpec, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); }, @@ -428,7 +428,7 @@ public Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -439,7 +439,7 @@ public Connection connect( * host */ public Connection forceConnect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final @Nullable ConnectionPlugin pluginToSkip) @@ -449,7 +449,7 @@ public Connection forceConnect( return executeWithSubscribedPlugins( JdbcMethod.FORCECONNECT, (plugin, func) -> - plugin.forceConnect(connectionContext, hostSpec, isInitialConnection, func), + plugin.forceConnect(connectionInfo, hostSpec, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); }, @@ -550,7 +550,7 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin } public void initHostProvider( - final ConnectionContext connectionContext, final HostListProviderService hostListProviderService) + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService) throws SQLException { TelemetryContext context = this.telemetryFactory.openTelemetryContext( "initHostProvider", TelemetryTraceLevel.NESTED); @@ -560,7 +560,7 @@ public void initHostProvider( JdbcMethod.INITHOSTPROVIDER, (PluginPipeline) (plugin, func) -> { - plugin.initHostProvider(connectionContext, hostListProviderService, func); + plugin.initHostProvider(connectionInfo, hostListProviderService, func); return null; }, () -> { diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java index 09f419bd1..46a1ee00b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java @@ -22,7 +22,7 @@ import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * Implement this interface in order to handle the physical connection creation process. @@ -33,12 +33,12 @@ public interface ConnectionProvider { * properties. Some ConnectionProvider implementations may not be able to handle certain URL * types or properties. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec the HostSpec containing the host-port information for the host to connect to * @return true if this ConnectionProvider can provide connections for the given URL, otherwise * return false */ - boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec); + boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec); /** * Indicates whether the selection strategy is supported by the connection provider. @@ -70,12 +70,12 @@ HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec the HostSpec containing the host-port information for the host to connect to * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ - Connection connect(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) throws SQLException; + Connection connect(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) throws SQLException; String getTargetName(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java index 89921dfa0..f99f6051b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java @@ -23,7 +23,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class ConnectionProviderManager { @@ -66,19 +66,19 @@ public static void setConnectionProvider(ConnectionProvider connProvider) { * non-default ConnectionProvider will be returned. Otherwise, the default ConnectionProvider will * be returned. See {@link ConnectionProvider#acceptsUrl} for more info. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param host the host info for the connection that will be established * @return the {@link ConnectionProvider} to use to establish a connection using the given driver * protocol, host details, and properties */ - public ConnectionProvider getConnectionProvider(ConnectionContext connectionContext, HostSpec host) { + public ConnectionProvider getConnectionProvider(ConnectionInfo connectionInfo, HostSpec host) { final ConnectionProvider customConnectionProvider = Driver.getCustomConnectionProvider(); - if (customConnectionProvider != null && customConnectionProvider.acceptsUrl(connectionContext, host)) { + if (customConnectionProvider != null && customConnectionProvider.acceptsUrl(connectionInfo, host)) { return customConnectionProvider; } - if (this.effectiveConnProvider != null && this.effectiveConnProvider.acceptsUrl(connectionContext, host)) { + if (this.effectiveConnProvider != null && this.effectiveConnProvider.acceptsUrl(connectionInfo, host)) { return this.effectiveConnProvider; } @@ -207,7 +207,7 @@ public static void resetConnectionInitFunc() { public void initConnection( final @Nullable Connection connection, - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec) throws SQLException { final ConnectionInitFunc connectionInitFunc = Driver.getConnectionInitFunc(); @@ -215,13 +215,13 @@ public void initConnection( return; } - connectionInitFunc.initConnection(connection, connectionContext, hostSpec); + connectionInitFunc.initConnection(connection, connectionInfo, hostSpec); } public interface ConnectionInitFunc { void initConnection( final @Nullable Connection connection, - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec) throws SQLException; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java index 353e0847d..b1342ab3c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * This class is a basic implementation of {@link ConnectionProvider} interface. It creates and @@ -67,7 +67,7 @@ public DataSourceConnectionProvider(final @NonNull DataSource dataSource) { } @Override - public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { return true; } @@ -93,16 +93,16 @@ public HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec The HostSpec containing the host-port information for the host to connect to * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ @Override public Connection connect( - final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); - connectionContext.getDbDialect().prepareConnectProperties(propsCopy, connectionContext.getProtocol(), hostSpec); + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec) throws SQLException { + final Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); + connectionInfo.getDbDialect().prepareConnectProperties(propsCopy, connectionInfo.getProtocol(), hostSpec); Connection conn; @@ -111,7 +111,7 @@ public Connection connect( LOGGER.finest(() -> "Use a separate DataSource object to create a connection."); // use a new data source instance to instantiate a connection final DataSource ds = createDataSource(); - conn = this.openConnection(ds, connectionContext, hostSpec, propsCopy); + conn = this.openConnection(ds, connectionInfo, hostSpec, propsCopy); } else { @@ -120,7 +120,7 @@ public Connection connect( this.lock.lock(); LOGGER.finest(() -> "Use main DataSource object to create a connection."); try { - conn = this.openConnection(this.dataSource, connectionContext, hostSpec, propsCopy); + conn = this.openConnection(this.dataSource, connectionInfo, hostSpec, propsCopy); } finally { this.lock.unlock(); } @@ -135,15 +135,15 @@ public Connection connect( protected Connection openConnection( final @NonNull DataSource ds, - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec, final @NonNull Properties props) throws SQLException { final boolean enableGreenNodeReplacement = PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(props); try { - connectionContext.getDriverDialect().prepareDataSource( + connectionInfo.getDriverDialect().prepareDataSource( ds, - connectionContext.getProtocol(), + connectionInfo.getProtocol(), hostSpec, props); return ds.getConnection(); @@ -192,9 +192,9 @@ protected Connection openConnection( .host(fixedHost) .build(); - connectionContext.getDriverDialect().prepareDataSource( + connectionInfo.getDriverDialect().prepareDataSource( this.dataSource, - connectionContext.getProtocol(), + connectionInfo.getProtocol(), connectionHostSpec, props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java index 403ff279e..729bad685 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java @@ -33,7 +33,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * This class is a basic implementation of {@link ConnectionProvider} interface. It creates and @@ -63,7 +63,7 @@ public DriverConnectionProvider(final java.sql.Driver driver) { } @Override - public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { return true; } @@ -89,19 +89,19 @@ public HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec The HostSpec containing the host-port information for the host to connect to * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ @Override - public Connection connect(final @NonNull ConnectionContext connectionContext, final @NonNull HostSpec hostSpec) + public Connection connect(final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); + final Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); final ConnectInfo connectInfo = - connectionContext.getDriverDialect().prepareConnectInfo(connectionContext.getProtocol(), hostSpec, propsCopy); + connectionInfo.getDriverDialect().prepareConnectInfo(connectionInfo.getProtocol(), hostSpec, propsCopy); - connectionContext.getDbDialect().prepareConnectProperties(propsCopy, connectionContext.getProtocol(), hostSpec); + connectionInfo.getDbDialect().prepareConnectProperties(propsCopy, connectionInfo.getProtocol(), hostSpec); LOGGER.finest(() -> "Connecting to " + connectInfo.url + PropertyUtils.logProperties( PropertyUtils.maskProperties(connectInfo.props), @@ -158,8 +158,8 @@ public Connection connect(final @NonNull ConnectionContext connectionContext, fi .host(fixedHost) .build(); - final ConnectInfo fixedConnectInfo = connectionContext.getDriverDialect().prepareConnectInfo( - connectionContext.getProtocol(), connectionHostSpec, propsCopy); + final ConnectInfo fixedConnectInfo = connectionInfo.getDriverDialect().prepareConnectInfo( + connectionInfo.getProtocol(), connectionHostSpec, propsCopy); LOGGER.finest(() -> "Connecting to " + fixedConnectInfo.url + " after correcting the hostname from " + originalHost diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java index fe74a6a5e..af0100f54 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class HikariPooledConnectionProvider implements PooledConnectionProvider, @@ -203,9 +203,9 @@ public HikariPooledConnectionProvider( @Override - public boolean acceptsUrl(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { if (this.acceptsUrlFunc != null) { - return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionContext.getProps()); + return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionInfo.getProps()); } final RdsUrlType urlType = rdsUtils.identifyRdsType(hostSpec.getHost()); @@ -238,9 +238,9 @@ public HostSpec getHostSpecByStrategy( } @Override - public Connection connect(@NonNull ConnectionContext connectionContext, @NonNull HostSpec hostSpec) + public Connection connect(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); + final Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); HostSpec connectionHostSpec = hostSpec; if (PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(propsCopy) @@ -267,12 +267,12 @@ public Connection connect(@NonNull ConnectionContext connectionContext, @NonNull } final HostSpec finalHostSpec = connectionHostSpec; - connectionContext.getDbDialect().prepareConnectProperties( - propsCopy, connectionContext.getProtocol(), finalHostSpec); + connectionInfo.getDbDialect().prepareConnectProperties( + propsCopy, connectionInfo.getProtocol(), finalHostSpec); final HikariDataSource ds = (HikariDataSource) HikariPoolsHolder.databasePools.computeIfAbsent( Pair.create(hostSpec.getUrl(), getPoolKey(finalHostSpec, propsCopy)), - (lambdaPoolKey) -> createHikariDataSource(connectionContext, finalHostSpec, propsCopy), + (lambdaPoolKey) -> createHikariDataSource(connectionInfo, finalHostSpec, propsCopy), poolExpirationCheckNanos ); @@ -306,21 +306,21 @@ public void releaseResources() { * HikariConfig passed to this method should be created via a * {@link HikariPoolConfigurator}, which allows the user to specify any * additional configuration properties. - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec the host details used to form the connection * @param connectionProps the connection properties */ protected void configurePool( final HikariConfig config, - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final Properties connectionProps) { final Properties copy = PropertyUtils.copyProperties(connectionProps); ConnectInfo connectInfo; try { - connectInfo = connectionContext.getDriverDialect().prepareConnectInfo( - connectionContext.getProtocol(), hostSpec, copy); + connectInfo = connectionInfo.getDriverDialect().prepareConnectInfo( + connectionInfo.getProtocol(), hostSpec, copy); } catch (SQLException ex) { throw new RuntimeException(ex); } @@ -406,12 +406,12 @@ public void logConnections() { } HikariDataSource createHikariDataSource( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final Properties props) { HikariConfig config = poolConfigurator.configurePool(hostSpec, props); - configurePool(config, connectionContext, hostSpec, props); + configurePool(config, connectionInfo, hostSpec, props); return new HikariDataSource(config); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 6b8cc318b..6658a2eb2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -48,7 +48,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.CacheMap; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -66,7 +66,7 @@ public class PartialPluginService implements PluginService, CanReleaseResources, protected static final CacheMap hostAvailabilityExpiringCache = new CacheMap<>(); protected final FullServicesContainer servicesContainer; - protected final ConnectionContext connectionContext; + protected final ConnectionInfo connectionInfo; protected final ConnectionPluginManager pluginManager; protected volatile HostListProvider hostListProvider; protected List allHosts = new ArrayList<>(); @@ -79,18 +79,18 @@ public class PartialPluginService implements PluginService, CanReleaseResources, protected final ConnectionProviderManager connectionProviderManager; public PartialPluginService( - @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionContext connectionContext) { + @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionInfo connectionInfo) { this( servicesContainer, new ExceptionManager(), - connectionContext, + connectionInfo, null); } public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, @NonNull final ExceptionManager exceptionManager, - @NonNull final ConnectionContext connectionContext, + @NonNull final ConnectionInfo connectionInfo, @Nullable final ConfigurationProfile configurationProfile) { this.servicesContainer = servicesContainer; this.servicesContainer.setHostListProviderService(this); @@ -98,7 +98,7 @@ public PartialPluginService( this.servicesContainer.setPluginManagerService(this); this.pluginManager = servicesContainer.getConnectionPluginManager(); - this.connectionContext = connectionContext; + this.connectionInfo = connectionInfo; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; @@ -110,8 +110,8 @@ public PartialPluginService( ? this.configurationProfile.getExceptionHandler() : null; - HostListProviderSupplier supplier = this.connectionContext.getDbDialect().getHostListProvider(); - this.hostListProvider = supplier.getProvider(this.connectionContext, this.servicesContainer); + HostListProviderSupplier supplier = this.connectionInfo.getDbDialect().getHostListProvider(); + this.hostListProvider = supplier.getProvider(this.connectionInfo, this.servicesContainer); } @Override @@ -163,13 +163,13 @@ public HostSpec getInitialConnectionHostSpec() { } @Override - public ConnectionContext getConnectionContext() { - return this.connectionContext; + public ConnectionInfo getConnectionInfo() { + return this.connectionInfo; } @Override public String getOriginalUrl() { - return this.connectionContext.getInitialConnectionString(); + return this.connectionInfo.getInitialConnectionString(); } @Override @@ -214,13 +214,13 @@ public ConnectionProvider getDefaultConnectionProvider() { public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = - this.connectionProviderManager.getConnectionProvider(this.connectionContext, host); + this.connectionProviderManager.getConnectionProvider(this.connectionInfo, host); return (connectionProvider instanceof PooledConnectionProvider); } @Override public String getDriverProtocol() { - return this.connectionContext.getProtocol(); + return this.connectionInfo.getProtocol(); } @Override @@ -504,7 +504,7 @@ public Connection forceConnect( final Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { - return this.pluginManager.forceConnect(this.connectionContext, hostSpec,true, pluginToSkip); + return this.pluginManager.forceConnect(this.connectionInfo, hostSpec,true, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -524,7 +524,7 @@ public void releaseResources() { @Override public boolean isNetworkException(Throwable throwable) { - return this.isNetworkException(throwable, this.connectionContext.getDriverDialect()); + return this.isNetworkException(throwable, this.connectionInfo.getDriverDialect()); } @Override @@ -534,7 +534,7 @@ public boolean isNetworkException(final Throwable throwable, @Nullable TargetDri } return this.exceptionManager.isNetworkException( - this.connectionContext.getDbDialect(), throwable, targetDriverDialect); + this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -543,12 +543,12 @@ public boolean isNetworkException(final String sqlState) { return this.exceptionHandler.isNetworkException(sqlState); } - return this.exceptionManager.isNetworkException(this.connectionContext.getDbDialect(), sqlState); + return this.exceptionManager.isNetworkException(this.connectionInfo.getDbDialect(), sqlState); } @Override public boolean isLoginException(Throwable throwable) { - return this.isLoginException(throwable, this.connectionContext.getDriverDialect()); + return this.isLoginException(throwable, this.connectionInfo.getDriverDialect()); } @Override @@ -558,7 +558,7 @@ public boolean isLoginException(final Throwable throwable, @Nullable TargetDrive } return this.exceptionManager.isLoginException( - this.connectionContext.getDbDialect(), throwable, targetDriverDialect); + this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -566,17 +566,17 @@ public boolean isLoginException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(sqlState); } - return this.exceptionManager.isLoginException(this.connectionContext.getDbDialect(), sqlState); + return this.exceptionManager.isLoginException(this.connectionInfo.getDbDialect(), sqlState); } @Override public Dialect getDialect() { - return this.connectionContext.getDbDialect(); + return this.connectionInfo.getDbDialect(); } @Override public TargetDriverDialect getTargetDriverDialect() { - return this.connectionContext.getDriverDialect(); + return this.connectionInfo.getDriverDialect(); } @Override @@ -624,12 +624,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getProps())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionInfo.getProps())); } @Override public Properties getProperties() { - return this.connectionContext.getProps(); + return this.connectionInfo.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index e1af0469e..ffc7a62a5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -29,7 +29,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryFactory; /** @@ -83,10 +83,10 @@ EnumSet setCurrentConnection( HostSpec getInitialConnectionHostSpec(); /** - * Get the {@link ConnectionContext} for the current original connection. - * @return the {@link ConnectionContext} for the current original connection. + * Get the {@link ConnectionInfo} for the current original connection. + * @return the {@link ConnectionInfo} for the current original connection. */ - ConnectionContext getConnectionContext(); + ConnectionInfo getConnectionInfo(); String getOriginalUrl(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 82772eefb..0e5c4d267 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -53,7 +53,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.CacheMap; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -70,7 +70,7 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, protected static final long DEFAULT_STATUS_CACHE_EXPIRE_NANO = TimeUnit.MINUTES.toNanos(60); protected final ConnectionPluginManager pluginManager; - protected final ConnectionContext connectionContext; + protected final ConnectionInfo connectionInfo; protected volatile HostListProvider hostListProvider; protected List allHosts = new ArrayList<>(); protected Connection currentConnection; @@ -88,19 +88,19 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, protected final ReentrantLock connectionSwitchLock = new ReentrantLock(); public PluginServiceImpl( - @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionContext connectionContext) + @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionInfo connectionInfo) throws SQLException { - this(servicesContainer, new ExceptionManager(), connectionContext, null, null, null); + this(servicesContainer, new ExceptionManager(), connectionInfo, null, null, null); } public PluginServiceImpl( @NonNull final FullServicesContainer servicesContainer, - @NonNull final ConnectionContext connectionContext, + @NonNull final ConnectionInfo connectionInfo, @Nullable final ConfigurationProfile configurationProfile) throws SQLException { this( servicesContainer, new ExceptionManager(), - connectionContext, + connectionInfo, null, configurationProfile, null); @@ -109,13 +109,13 @@ public PluginServiceImpl( public PluginServiceImpl( @NonNull final FullServicesContainer servicesContainer, @NonNull final ExceptionManager exceptionManager, - @NonNull final ConnectionContext connectionContext, + @NonNull final ConnectionInfo connectionInfo, @Nullable final DialectProvider dialectProvider, @Nullable final ConfigurationProfile configurationProfile, @Nullable final SessionStateService sessionStateService) throws SQLException { this.servicesContainer = servicesContainer; this.pluginManager = servicesContainer.getConnectionPluginManager(); - this.connectionContext = connectionContext; + this.connectionInfo = connectionInfo; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; this.dialectProvider = dialectProvider != null ? dialectProvider : new DialectManager(this); @@ -125,7 +125,7 @@ public PluginServiceImpl( this.sessionStateService = sessionStateService != null ? sessionStateService - : new SessionStateServiceImpl(this, this.connectionContext.getProps()); + : new SessionStateServiceImpl(this, this.connectionInfo.getProps()); this.exceptionHandler = this.configurationProfile != null && this.configurationProfile.getExceptionHandler() != null ? this.configurationProfile.getExceptionHandler() @@ -133,8 +133,8 @@ public PluginServiceImpl( Dialect dialect = this.configurationProfile != null && this.configurationProfile.getDialect() != null ? this.configurationProfile.getDialect() - : this.dialectProvider.getDialect(this.connectionContext); - this.connectionContext.setDbDialect(dialect); + : this.dialectProvider.getDialect(this.connectionInfo); + this.connectionInfo.setDbDialect(dialect); } @Override @@ -185,13 +185,13 @@ public HostSpec getInitialConnectionHostSpec() { } @Override - public ConnectionContext getConnectionContext() { - return this.connectionContext; + public ConnectionInfo getConnectionInfo() { + return this.connectionInfo; } @Override public String getOriginalUrl() { - return this.connectionContext.getInitialConnectionString(); + return this.connectionInfo.getInitialConnectionString(); } @Override @@ -233,13 +233,13 @@ public ConnectionProvider getDefaultConnectionProvider() { public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = - this.connectionProviderManager.getConnectionProvider(this.connectionContext, host); + this.connectionProviderManager.getConnectionProvider(this.connectionInfo, host); return (connectionProvider instanceof PooledConnectionProvider); } @Override public String getDriverProtocol() { - return this.getConnectionContext().getProtocol(); + return this.getConnectionInfo().getProtocol(); } @Override @@ -291,7 +291,7 @@ public EnumSet setCurrentConnection( this.setInTransaction(false); if (isInTransaction - && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionContext.getProps())) { + && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionInfo.getProps())) { try { oldConnection.rollback(); } catch (final SQLException e) { @@ -592,7 +592,7 @@ public Connection connect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.connect( - this.connectionContext, hostSpec, this.currentConnection == null, pluginToSkip); + this.connectionInfo, hostSpec, this.currentConnection == null, pluginToSkip); } @Override @@ -610,7 +610,7 @@ public Connection forceConnect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.forceConnect( - this.connectionContext, hostSpec, this.currentConnection == null, pluginToSkip); + this.connectionInfo, hostSpec, this.currentConnection == null, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -643,7 +643,7 @@ public void releaseResources() { @Override @Deprecated public boolean isNetworkException(Throwable throwable) { - return this.isNetworkException(throwable, this.connectionContext.getDriverDialect()); + return this.isNetworkException(throwable, this.connectionInfo.getDriverDialect()); } @Override @@ -651,7 +651,7 @@ public boolean isNetworkException(final Throwable throwable, @Nullable TargetDri if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(throwable, targetDriverDialect); } - return this.exceptionManager.isNetworkException(this.connectionContext.getDbDialect(), throwable, targetDriverDialect); + return this.exceptionManager.isNetworkException(this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -659,13 +659,13 @@ public boolean isNetworkException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(sqlState); } - return this.exceptionManager.isNetworkException(this.connectionContext.getDbDialect(), sqlState); + return this.exceptionManager.isNetworkException(this.connectionInfo.getDbDialect(), sqlState); } @Override @Deprecated public boolean isLoginException(Throwable throwable) { - return this.isLoginException(throwable, this.connectionContext.getDriverDialect()); + return this.isLoginException(throwable, this.connectionInfo.getDriverDialect()); } @Override @@ -673,7 +673,7 @@ public boolean isLoginException(final Throwable throwable, @Nullable TargetDrive if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(throwable, targetDriverDialect); } - return this.exceptionManager.isLoginException(this.connectionContext.getDbDialect(), throwable, targetDriverDialect); + return this.exceptionManager.isLoginException(this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -681,30 +681,30 @@ public boolean isLoginException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(sqlState); } - return this.exceptionManager.isLoginException(this.connectionContext.getDbDialect(), sqlState); + return this.exceptionManager.isLoginException(this.connectionInfo.getDbDialect(), sqlState); } @Override public Dialect getDialect() { - return this.connectionContext.getDbDialect(); + return this.connectionInfo.getDbDialect(); } @Override public TargetDriverDialect getTargetDriverDialect() { - return this.connectionContext.getDriverDialect(); + return this.connectionInfo.getDriverDialect(); } public void updateDialect(final @NonNull Connection connection) throws SQLException { - final Dialect originalDialect = this.connectionContext.getDbDialect(); + final Dialect originalDialect = this.connectionInfo.getDbDialect(); Dialect currentDialect = this.dialectProvider.getDialect( - this.connectionContext.getProtocol(), this.initialConnectionHostSpec, connection); + this.connectionInfo.getProtocol(), this.initialConnectionHostSpec, connection); if (originalDialect == currentDialect) { return; } - this.connectionContext.setDbDialect(currentDialect); - final HostListProviderSupplier supplier = this.connectionContext.getDbDialect().getHostListProvider(); - this.setHostListProvider(supplier.getProvider(this.connectionContext, this.servicesContainer)); + this.connectionInfo.setDbDialect(currentDialect); + final HostListProviderSupplier supplier = this.connectionInfo.getDbDialect().getHostListProvider(); + this.setHostListProvider(supplier.getProvider(this.connectionInfo, this.servicesContainer)); this.refreshHostList(connection); } @@ -747,12 +747,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionContext.getProps())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionInfo.getProps())); } @Override public Properties getProperties() { - return this.connectionContext.getProps(); + return this.connectionInfo.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java index ede526f9b..5d3675674 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java @@ -89,11 +89,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> { + return (connectionInfo, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsHostListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -102,7 +102,7 @@ public HostListProviderSupplier getHostListProvider() { } return new AuroraHostListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java index e00cdc3df..34d589c2a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java @@ -135,11 +135,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> { + return (connectionInfo, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsHostListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -148,7 +148,7 @@ public HostListProviderSupplier getHostListProvider() { } return new AuroraHostListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index 771ed8c1d..075ca0687 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -35,7 +35,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.CacheMap; public class DialectManager implements DialectProvider { @@ -119,7 +119,7 @@ public static void resetEndpointCache() { } @Override - public Dialect getDialect(final @NonNull ConnectionContext connectionContext) throws SQLException { + public Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws SQLException { this.canUpdate = false; this.dialect = null; @@ -131,10 +131,10 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th return this.dialect; } - final String userDialectSetting = DIALECT.getString(connectionContext.getProps()); + final String userDialectSetting = DIALECT.getString(connectionInfo.getProps()); final String dialectCode = !StringUtils.isNullOrEmpty(userDialectSetting) ? userDialectSetting - : knownEndpointDialects.get(connectionContext.getInitialConnectionString()); + : knownEndpointDialects.get(connectionInfo.getInitialConnectionString()); if (!StringUtils.isNullOrEmpty(dialectCode)) { final Dialect userDialect = knownDialectsByCode.get(dialectCode); @@ -149,18 +149,18 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th } } - if (StringUtils.isNullOrEmpty(connectionContext.getProtocol())) { + if (StringUtils.isNullOrEmpty(connectionInfo.getProtocol())) { throw new IllegalArgumentException("protocol"); } - String connectionString = connectionContext.getInitialConnectionString(); + String connectionString = connectionInfo.getInitialConnectionString(); final List hosts = this.connectionUrlParser.getHostsFromConnectionUrl( - connectionContext.getInitialConnectionString(), true, pluginService::getHostSpecBuilder); + connectionInfo.getInitialConnectionString(), true, pluginService::getHostSpecBuilder); if (!Utils.isNullOrEmpty(hosts)) { connectionString = hosts.get(0).getHost(); } - if (connectionContext.getProtocol().contains("mysql")) { + if (connectionInfo.getProtocol().contains("mysql")) { RdsUrlType type = this.rdsHelper.identifyRdsType(connectionString); if (type.isRdsCluster()) { this.canUpdate = true; @@ -182,7 +182,7 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th return this.dialect; } - if (connectionContext.getProtocol().contains("postgresql")) { + if (connectionInfo.getProtocol().contains("postgresql")) { RdsUrlType type = this.rdsHelper.identifyRdsType(connectionString); if (RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP.equals(type)) { this.canUpdate = false; @@ -210,7 +210,7 @@ public Dialect getDialect(final @NonNull ConnectionContext connectionContext) th return this.dialect; } - if (connectionContext.getProtocol().contains("mariadb")) { + if (connectionInfo.getProtocol().contains("mariadb")) { this.canUpdate = true; this.dialectCode = DialectCodes.MARIADB; this.dialect = knownDialectsByCode.get(DialectCodes.MARIADB); diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java index ed0f4cae7..384aeb22d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java @@ -18,13 +18,12 @@ import java.sql.Connection; import java.sql.SQLException; -import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public interface DialectProvider { - Dialect getDialect(final @NonNull ConnectionContext connectionContext) throws SQLException; + Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws SQLException; Dialect getDialect( final @NonNull String originalUrl, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java index 4515d6285..ad6ddbc19 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java @@ -19,11 +19,11 @@ import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostListProvider; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; @FunctionalInterface public interface HostListProviderSupplier { @NonNull HostListProvider getProvider( - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull FullServicesContainer servicesContainer); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java index 52e104c07..4bd593c34 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java @@ -104,8 +104,8 @@ public List getDialectUpdateCandidates() { } public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> - new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); + return (connectionInfo, servicesContainer) -> + new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java index 26043aa41..db5e7556c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java @@ -105,8 +105,8 @@ public List getDialectUpdateCandidates() { } public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> - new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); + return (connectionInfo, servicesContainer) -> + new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java index 300436604..ebd48c64d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java @@ -106,8 +106,8 @@ public List getDialectUpdateCandidates() { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> - new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); + return (connectionInfo, servicesContainer) -> + new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java index 756591474..7ac9bcc10 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java @@ -94,11 +94,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> { + return (connectionInfo, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsMultiAzHostListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -108,7 +108,7 @@ public HostListProviderSupplier getHostListProvider() { } else { return new RdsMultiAzDbClusterListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java index 23e9d6b78..fcf94826e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java @@ -80,11 +80,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> { + return (connectionInfo, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsMultiAzHostListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -95,7 +95,7 @@ public HostListProviderSupplier getHostListProvider() { } else { return new RdsMultiAzDbClusterListProvider( - connectionContext, + connectionInfo, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java index e1ecc937e..7350a52ca 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java @@ -81,8 +81,8 @@ public List getDialectUpdateCandidates() { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionContext, servicesContainer) -> - new ConnectionStringHostListProvider(connectionContext, servicesContainer.getHostListProviderService()); + return (connectionInfo, servicesContainer) -> + new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java index ab1bac786..12cb429e1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java @@ -19,7 +19,7 @@ import java.util.logging.Logger; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class AuroraHostListProvider extends RdsHostListProvider { @@ -27,13 +27,13 @@ public class AuroraHostListProvider extends RdsHostListProvider { static final Logger LOGGER = Logger.getLogger(AuroraHostListProvider.class.getName()); public AuroraHostListProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery) { super( - connectionContext, + connectionInfo, servicesContainer, topologyQuery, nodeIdQuery, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index 4311917f2..09239d6d2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -29,7 +29,7 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.util.ConnectionUrlParser; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class ConnectionStringHostListProvider implements StaticHostListProvider { @@ -39,7 +39,7 @@ public class ConnectionStringHostListProvider implements StaticHostListProvider private boolean isInitialized = false; private final boolean isSingleWriterConnectionString; private final ConnectionUrlParser connectionUrlParser; - private final ConnectionContext connectionContext; + private final ConnectionInfo connectionInfo; private final HostListProviderService hostListProviderService; public static final AwsWrapperProperty SINGLE_WRITER_CONNECTION_STRING = @@ -50,17 +50,17 @@ public class ConnectionStringHostListProvider implements StaticHostListProvider + "cluster has only one writer. The writer must be the first host in the connection string"); public ConnectionStringHostListProvider( - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostListProviderService hostListProviderService) { - this(connectionContext, hostListProviderService, new ConnectionUrlParser()); + this(connectionInfo, hostListProviderService, new ConnectionUrlParser()); } ConnectionStringHostListProvider( - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostListProviderService hostListProviderService, final @NonNull ConnectionUrlParser connectionUrlParser) { - this.connectionContext = connectionContext; - this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionContext.getProps()); + this.connectionInfo = connectionInfo; + this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionInfo.getProps()); this.connectionUrlParser = connectionUrlParser; this.hostListProviderService = hostListProviderService; } @@ -71,12 +71,12 @@ private void init() throws SQLException { } this.hostList.addAll( this.connectionUrlParser.getHostsFromConnectionUrl( - this.connectionContext.getInitialConnectionString(), + this.connectionInfo.getInitialConnectionString(), this.isSingleWriterConnectionString, () -> this.hostListProviderService.getHostSpecBuilder())); if (this.hostList.isEmpty()) { throw new SQLException(Messages.get("ConnectionStringHostListProvider.parsedListEmpty", - new Object[] {this.connectionContext.getInitialConnectionString()})); + new Object[] {this.connectionInfo.getInitialConnectionString()})); } this.hostListProviderService.setInitialConnectionHostSpec(this.hostList.get(0)); this.isInitialized = true; diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java index 6e043631f..97800cbe1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java @@ -55,7 +55,7 @@ import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SynchronousExecutor; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.CacheMap; public class RdsHostListProvider implements DynamicHostListProvider { @@ -95,7 +95,7 @@ public class RdsHostListProvider implements DynamicHostListProvider { protected final FullServicesContainer servicesContainer; protected final HostListProviderService hostListProviderService; - protected final ConnectionContext connectionContext; + protected final ConnectionInfo connectionInfo; protected final String topologyQuery; protected final String nodeIdQuery; protected final String isReaderQuery; @@ -122,12 +122,12 @@ public class RdsHostListProvider implements DynamicHostListProvider { } public RdsHostListProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery) { - this.connectionContext = connectionContext; + this.connectionInfo = connectionInfo; this.servicesContainer = servicesContainer; this.hostListProviderService = servicesContainer.getHostListProviderService(); this.topologyQuery = topologyQuery; @@ -148,18 +148,18 @@ protected void init() throws SQLException { // initial topology is based on connection string this.initialHostList = - connectionUrlParser.getHostsFromConnectionUrl(this.connectionContext.getInitialConnectionString(), false, + connectionUrlParser.getHostsFromConnectionUrl(this.connectionInfo.getInitialConnectionString(), false, this.hostListProviderService::getHostSpecBuilder); if (this.initialHostList == null || this.initialHostList.isEmpty()) { throw new SQLException(Messages.get("RdsHostListProvider.parsedListEmpty", - new Object[] {this.connectionContext.getInitialConnectionString()})); + new Object[] {this.connectionInfo.getInitialConnectionString()})); } this.initialHostSpec = this.initialHostList.get(0); this.hostListProviderService.setInitialConnectionHostSpec(this.initialHostSpec); this.clusterId = UUID.randomUUID().toString(); this.isPrimaryClusterId = false; - Properties props = this.connectionContext.getProps(); + Properties props = this.connectionInfo.getProps(); this.refreshRateNano = TimeUnit.MILLISECONDS.toNanos(CLUSTER_TOPOLOGY_REFRESH_RATE_MS.getInteger(props)); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java index 663b024f2..a32748ef1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java @@ -32,7 +32,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class RdsMultiAzDbClusterListProvider extends RdsHostListProvider { private final String fetchWriterNodeQuery; @@ -40,7 +40,7 @@ public class RdsMultiAzDbClusterListProvider extends RdsHostListProvider { static final Logger LOGGER = Logger.getLogger(RdsMultiAzDbClusterListProvider.class.getName()); public RdsMultiAzDbClusterListProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, @@ -49,7 +49,7 @@ public RdsMultiAzDbClusterListProvider( final String fetchWriterNodeQueryHeader ) { super( - connectionContext, + connectionInfo, servicesContainer, topologyQuery, nodeIdQuery, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index d2aacbab1..4ce035c6d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -510,7 +510,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.servicesContainer.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.servicesContainer.getPluginService().getConnectionContext() + this.servicesContainer.getPluginService().getConnectionInfo() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java index 6eb96d5e6..fea674ae1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java @@ -31,7 +31,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; @@ -56,18 +56,18 @@ public class MonitoringRdsHostListProvider extends RdsHostListProvider protected final String writerTopologyQuery; public MonitoringRdsHostListProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery, final String writerTopologyQuery) { - super(connectionContext, servicesContainer, topologyQuery, nodeIdQuery, isReaderQuery); + super(connectionInfo, servicesContainer, topologyQuery, nodeIdQuery, isReaderQuery); this.servicesContainer = servicesContainer; this.pluginService = servicesContainer.getPluginService(); this.writerTopologyQuery = writerTopologyQuery; this.highRefreshRateNano = TimeUnit.MILLISECONDS.toNanos( - CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.getLong(this.connectionContext.getProps())); + CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.getLong(this.connectionInfo.getProps())); } public static void clearCache() { @@ -86,12 +86,12 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.connectionContext, + this.connectionInfo, (servicesContainer) -> new ClusterTopologyMonitorImpl( this.servicesContainer, this.clusterId, this.initialHostSpec, - this.connectionContext.getProps(), + this.connectionInfo.getProps(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, @@ -127,7 +127,7 @@ protected void clusterIdChanged(final String oldClusterId) throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.connectionContext, + this.connectionInfo, (servicesContainer) -> existingMonitor); assert monitorService.get(ClusterTopologyMonitorImpl.class, this.clusterId) == existingMonitor; existingMonitor.setClusterId(this.clusterId); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java index c8626e344..30c2c77da 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java @@ -19,7 +19,7 @@ import java.sql.SQLException; import java.util.logging.Logger; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class MonitoringRdsMultiAzHostListProvider extends MonitoringRdsHostListProvider { @@ -29,7 +29,7 @@ public class MonitoringRdsMultiAzHostListProvider extends MonitoringRdsHostListP protected final String fetchWriterNodeColumnName; public MonitoringRdsMultiAzHostListProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, @@ -37,7 +37,7 @@ public MonitoringRdsMultiAzHostListProvider( final String fetchWriterNodeQuery, final String fetchWriterNodeColumnName) { super( - connectionContext, + connectionInfo, servicesContainer, topologyQuery, nodeIdQuery, @@ -55,12 +55,12 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.connectionContext, + this.connectionInfo, (servicesContainer) -> new MultiAzClusterTopologyMonitorImpl( servicesContainer, this.clusterId, this.initialHostSpec, - this.connectionContext.getProps(), + this.connectionInfo.getProps(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java index 0d704c53d..2a0561898 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java @@ -29,7 +29,7 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.OldConnectionSuggestedAction; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public abstract class AbstractConnectionPlugin implements ConnectionPlugin { @@ -50,7 +50,7 @@ public T execute( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) @@ -60,7 +60,7 @@ public Connection connect( @Override public Connection forceConnect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) @@ -87,7 +87,7 @@ public HostSpec getHostSpecByStrategy(final List hosts, final HostRole @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 19454b941..43796f425 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { @@ -84,7 +84,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 610ae776a..388d993ac 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -39,7 +39,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlugin { @@ -120,7 +120,7 @@ public Set getSubscribedMethods() { @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -133,13 +133,13 @@ public void initHostProvider( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost()); - final Properties props = connectionContext.getProps(); + final Properties props = connectionInfo.getProps(); if (type == RdsUrlType.RDS_WRITER_CLUSTER || isInitialConnection && this.verifyOpenedConnectionType == VerifyOpenedConnectionType.WRITER) { Connection writerCandidateConn = this.getVerifiedWriterConnection(props, isInitialConnection, connectFunc); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java index b4a815ef6..ec5e24688 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java @@ -50,7 +50,7 @@ import software.amazon.jdbc.util.Pair; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -185,12 +185,12 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getProps(), connectFunc); + return connectInternal(hostSpec, connectionInfo.getProps(), connectFunc); } private Connection connectInternal(HostSpec hostSpec, Properties props, @@ -226,12 +226,12 @@ private Connection connectInternal(HostSpec hostSpec, Properties props, @Override public Connection forceConnect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getProps(), forceConnectFunc); + return connectInternal(hostSpec, connectionInfo.getProps(), forceConnectFunc); } /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java index f5b071d44..313aeb666 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java @@ -26,7 +26,7 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class ConnectTimeConnectionPlugin extends AbstractConnectionPlugin { @@ -44,7 +44,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { @@ -63,7 +63,7 @@ public Connection connect( @Override public Connection forceConnect( - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java index 7253df836..ed51292a3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java @@ -45,7 +45,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlMethodAnalyzer; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; @@ -154,20 +154,20 @@ public T execute( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - ConnectionProvider connProvider = this.connProviderManager.getConnectionProvider(connectionContext, hostSpec); + ConnectionProvider connProvider = this.connProviderManager.getConnectionProvider(connectionInfo, hostSpec); // It's guaranteed that this plugin is always the last in plugin chain so connectFunc can be // ignored. - return connectInternal(connectionContext, hostSpec, connProvider, isInitialConnection); + return connectInternal(connectionInfo, hostSpec, connProvider, isInitialConnection); } private Connection connectInternal( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final ConnectionProvider connProvider, final boolean isInitialConnection) throws SQLException { @@ -177,14 +177,14 @@ private Connection connectInternal( Connection conn; try { - conn = connProvider.connect(connectionContext, hostSpec); + conn = connProvider.connect(connectionInfo, hostSpec); } finally { if (telemetryContext != null) { telemetryContext.closeContext(); } } - this.connProviderManager.initConnection(conn, connectionContext, hostSpec); + this.connProviderManager.initConnection(conn, connectionInfo, hostSpec); this.pluginService.setAvailability(hostSpec.asAliases(), HostAvailability.AVAILABLE); if (isInitialConnection) { @@ -196,7 +196,7 @@ private Connection connectInternal( @Override public Connection forceConnect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) @@ -204,7 +204,7 @@ public Connection forceConnect( // It's guaranteed that this plugin is always the last in plugin chain so forceConnectFunc can be // ignored. - return connectInternal(connectionContext, hostSpec, this.defaultConnProvider, isInitialConnection); + return connectInternal(connectionInfo, hostSpec, this.defaultConnProvider, isInitialConnection); } @Override @@ -242,7 +242,7 @@ public HostSpec getHostSpecByStrategy(final List hosts, final HostRole @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java index be72e8b32..c63e3740c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java @@ -42,7 +42,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -127,7 +127,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java index fea69250b..69db5d091 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java @@ -57,7 +57,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class BlueGreenStatusMonitor { @@ -599,7 +599,7 @@ protected void initHostListProvider() { return; } - final ConnectionContext originalContext = this.pluginService.getConnectionContext(); + final ConnectionInfo originalContext = this.pluginService.getConnectionInfo(); final Properties hostListProperties = originalContext.getProps(); // Need to instantiate a separate HostListProvider with @@ -616,7 +616,7 @@ protected void initHostListProvider() { if (connectionHostSpecCopy != null) { String hostListProviderUrl = String.format("%s%s/", originalContext.getProtocol(), connectionHostSpecCopy.getHostAndPort()); - ConnectionContext newContext = new ConnectionContext( + ConnectionInfo newContext = new ConnectionInfo( hostListProviderUrl, originalContext.getProtocol(), originalContext.getDriverDialect(), hostListProperties); this.hostListProvider = this.pluginService.getDialect().getHostListProvider().getProvider(newContext, this.servicesContainer); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java index 1ade11648..42ab0e189 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.monitoring.MonitorErrorResponse; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -164,7 +164,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -217,7 +217,7 @@ protected CustomEndpointMonitor createMonitorIfAbsent(Properties props) throws S this.servicesContainer.getStorageService(), this.pluginService.getTelemetryFactory(), this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getConnectionContext(), + this.pluginService.getConnectionInfo(), (servicesContainer) -> new CustomEndpointMonitorImpl( servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java index 80b61d8c4..ba62818f8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java @@ -20,7 +20,6 @@ import java.sql.SQLException; import java.util.Collections; import java.util.HashSet; -import java.util.Properties; import java.util.Set; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; @@ -29,7 +28,7 @@ import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class DeveloperConnectionPlugin extends AbstractConnectionPlugin implements ExceptionSimulator { @@ -143,28 +142,28 @@ protected void raiseException( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - this.raiseExceptionOnConnectIfNeeded(connectionContext, hostSpec, isInitialConnection); - return super.connect(connectionContext, hostSpec, isInitialConnection, connectFunc); + this.raiseExceptionOnConnectIfNeeded(connectionInfo, hostSpec, isInitialConnection); + return super.connect(connectionInfo, hostSpec, isInitialConnection, connectFunc); } @Override public Connection forceConnect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - this.raiseExceptionOnConnectIfNeeded(connectionContext, hostSpec, isInitialConnection); - return super.connect(connectionContext, hostSpec, isInitialConnection, forceConnectFunc); + this.raiseExceptionOnConnectIfNeeded(connectionInfo, hostSpec, isInitialConnection); + return super.connect(connectionInfo, hostSpec, isInitialConnection, forceConnectFunc); } protected void raiseExceptionOnConnectIfNeeded( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection) throws SQLException { @@ -174,7 +173,7 @@ protected void raiseExceptionOnConnectIfNeeded( } else if (ExceptionSimulatorManager.connectCallback != null) { this.raiseExceptionOnConnect( ExceptionSimulatorManager.connectCallback.getExceptionToRaise( - connectionContext, hostSpec, isInitialConnection)); + connectionInfo, hostSpec, isInitialConnection)); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java index 39f6a2453..28f8a6fe7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java @@ -18,11 +18,11 @@ import java.sql.SQLException; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public interface ExceptionSimulatorConnectCallback { SQLException getExceptionToRaise( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java index 493d2fe60..c0359c4e7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * Monitor the server while the connection is executing methods for more sophisticated failure @@ -270,7 +270,7 @@ public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java index 8e661d23b..7879a7f40 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java @@ -156,7 +156,7 @@ protected HostMonitor getMonitor( this.serviceContainer.getStorageService(), this.telemetryFactory, this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getConnectionContext(), + this.pluginService.getConnectionInfo(), (servicesContainer) -> new HostMonitorImpl( servicesContainer, hostSpec, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java index 698349a25..134234656 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * Monitor the server while the connection is executing methods for more sophisticated failure @@ -228,7 +228,7 @@ public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 38f1333e5..19c986edb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -364,7 +364,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.pluginService.getConnectionContext() + this.pluginService.getConnectionInfo() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index 81cb525af..d171387cf 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -174,7 +174,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.pluginService.getConnectionContext() + this.pluginService.getConnectionInfo() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 987f09a91..d7a2b88c2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -53,7 +53,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -302,7 +302,7 @@ public T execute( @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -904,7 +904,7 @@ private boolean canDirectExecute(final String methodName) { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -913,7 +913,7 @@ public Connection connect( Connection conn = null; try { conn = this.staleDnsHelper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); } catch (final SQLException e) { if (!this.enableConnectFailover || !shouldExceptionTriggerConnectionSwitch(e)) { throw e; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java index de45eeef4..4c831043d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java @@ -53,7 +53,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -233,7 +233,7 @@ public T execute( @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -727,17 +727,17 @@ protected void initFailoverMode() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { this.initFailoverMode(); Connection conn = null; - Properties props = connectionContext.getProps(); + Properties props = connectionInfo.getProps(); if (!ENABLE_CONNECT_FAILOVER.getBoolean(props)) { return this.staleDnsHelper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); } final HostSpec hostSpecWithAvailability = this.pluginService.getHosts().stream() @@ -750,7 +750,7 @@ public Connection connect( try { conn = this.staleDnsHelper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); } catch (final SQLException e) { if (!this.shouldExceptionTriggerConnectionSwitch(e)) { throw e; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index d1bbab792..b45e06c8b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -42,7 +42,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -153,21 +153,21 @@ public FederatedAuthPlugin(final PluginService pluginService, @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getProps(), connectFunc); + return connectInternal(hostSpec, connectionInfo.getProps(), connectFunc); } @Override public Connection forceConnect( - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec, final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getProps(), forceConnectFunc); + return connectInternal(hostSpec, connectionInfo.getProps(), forceConnectFunc); } private Connection connectInternal( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index 510c508cf..2a80104c7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -134,21 +134,21 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getProps(), connectFunc); + return connectInternal(hostSpec, connectionInfo.getProps(), connectFunc); } @Override public Connection forceConnect( - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionContext.getProps(), forceConnectFunc); + return connectInternal(hostSpec, connectionInfo.getProps(), forceConnectFunc); } private Connection connectInternal(final HostSpec hostSpec, final Properties props, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index eec9d7fc9..d314d6edf 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -107,18 +107,18 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(connectionContext, hostSpec, connectFunc); + return connectInternal(connectionInfo, hostSpec, connectFunc); } private Connection connectInternal( - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, HostSpec hostSpec, JdbcCallable connectFunc) throws SQLException { - Properties props = connectionContext.getProps(); + Properties props = connectionInfo.getProps(); if (StringUtils.isNullOrEmpty(PropertyDefinition.USER.getString(props))) { throw new SQLException(PropertyDefinition.USER.name + " is null or empty."); } @@ -226,12 +226,12 @@ private Connection connectInternal( @Override public Connection forceConnect( - final @NonNull ConnectionContext connectionContext, + final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec, final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(connectionContext, hostSpec, forceConnectFunc); + return connectInternal(connectionInfo, hostSpec, forceConnectFunc); } public static void clearCache() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java index 8284dc086..63bcf3094 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java @@ -35,7 +35,7 @@ import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class LimitlessConnectionPlugin extends AbstractConnectionPlugin { @@ -104,7 +104,7 @@ public LimitlessConnectionPlugin( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -130,7 +130,7 @@ public Connection connect( final LimitlessConnectionContext context = new LimitlessConnectionContext( hostSpec, - connectionContext.getProps(), + connectionInfo.getProps(), conn, connectFunc, null, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java index db2a27bb6..0ae98fed6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java @@ -323,7 +323,7 @@ public void startMonitoring(final @NonNull HostSpec hostSpec, final @NonNull Pro this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getConnectionContext(), + this.pluginService.getConnectionInfo(), (servicesContainer) -> new LimitlessRouterMonitor( servicesContainer, hostSpec, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index e4cdcc760..2c3482f12 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -44,7 +44,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implements CanReleaseResources { @@ -125,7 +125,7 @@ public Set getSubscribedMethods() { @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -135,7 +135,7 @@ public void initHostProvider( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java index 65e31b694..e82f65214 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java @@ -33,7 +33,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -61,7 +61,7 @@ public AuroraStaleDnsHelper(final PluginService pluginService) { public Connection getVerifiedConnection( final boolean isInitialConnection, final HostListProviderService hostListProviderService, - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final JdbcCallable connectFunc) throws SQLException { @@ -147,7 +147,7 @@ public Connection getVerifiedConnection( ); } - final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectionContext.getProps()); + final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectionInfo.getProps()); if (isInitialConnection) { hostListProviderService.setInitialConnectionHostSpec(this.writerHostSpec); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java index fb8eca976..c1f622b0b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java @@ -32,7 +32,7 @@ import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; /** * After Aurora DB cluster fail over is completed and a cluster has elected a new writer node, the corresponding @@ -76,17 +76,17 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { return this.helper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionContext, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); } @Override public void initHostProvider( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { this.hostListProviderService = hostListProviderService; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java index e35963955..ae06497a9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.RandomHostSelector; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.CacheMap; public class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { @@ -112,7 +112,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java index 6e15bfe0c..bbd31ec1d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java @@ -79,7 +79,7 @@ public void setHosts(final @NonNull List hosts) { servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), servicesContainer.getDefaultConnectionProvider(), - this.pluginService.getConnectionContext(), + this.pluginService.getConnectionInfo(), (servicesContainer) -> new NodeResponseTimeMonitor(pluginService, hostSpec, this.props, this.intervalMs)); } catch (SQLException e) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index a76432c2a..1364f3c2f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -22,7 +22,7 @@ import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.PartialPluginService; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -59,20 +59,20 @@ public FullServicesContainer createServiceContainer( MonitorService monitorService, ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, - ConnectionContext connectionContext) throws SQLException { + ConnectionInfo connectionInfo) throws SQLException { FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, connectionProvider, telemetryFactory); ConnectionPluginManager pluginManager = new ConnectionPluginManager( connectionProvider, null, null, telemetryFactory); servicesContainer.setConnectionPluginManager(pluginManager); - PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, connectionContext); + PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, connectionInfo); servicesContainer.setHostListProviderService(partialPluginService); servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); - Properties propsCopy = PropertyUtils.copyProperties(connectionContext.getProps()); + Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); pluginManager.init(servicesContainer, propsCopy, partialPluginService, null); return servicesContainer; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionInfo.java similarity index 92% rename from wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java rename to wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionInfo.java index a06b25802..75dbc7a0f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionContext.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionInfo.java @@ -21,7 +21,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.ConnectionUrlParser; -public class ConnectionContext { +public class ConnectionInfo { protected static final ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); protected final String initialConnectionString; protected final String protocol; @@ -29,11 +29,11 @@ public class ConnectionContext { protected final Properties props; protected Dialect dbDialect; - public ConnectionContext(String initialConnectionString, TargetDriverDialect driverDialect, Properties props) { + public ConnectionInfo(String initialConnectionString, TargetDriverDialect driverDialect, Properties props) { this(initialConnectionString, connectionUrlParser.getProtocol(initialConnectionString), driverDialect, props); } - public ConnectionContext( + public ConnectionInfo( String initialConnectionString, String protocol, TargetDriverDialect driverDialect, Properties props) { this.initialConnectionString = initialConnectionString; this.protocol = protocol; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 306bf7303..6e29ba7a7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -39,7 +39,7 @@ */ @Deprecated public class ConnectionServiceImpl implements ConnectionService { - protected final ConnectionContext connectionContext; + protected final ConnectionInfo connectionInfo; protected final ConnectionPluginManager pluginManager; protected final PluginService pluginService; @@ -54,8 +54,8 @@ public ConnectionServiceImpl( MonitorService monitorService, TelemetryFactory telemetryFactory, ConnectionProvider connectionProvider, - ConnectionContext connectionContext) throws SQLException { - this.connectionContext = connectionContext; + ConnectionInfo connectionInfo) throws SQLException { + this.connectionInfo = connectionInfo; FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, connectionProvider, telemetryFactory); @@ -65,20 +65,20 @@ public ConnectionServiceImpl( null, telemetryFactory); servicesContainer.setConnectionPluginManager(this.pluginManager); - PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, this.connectionContext); + PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, this.connectionInfo); servicesContainer.setHostListProviderService(partialPluginService); servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); this.pluginService = partialPluginService; - this.pluginManager.init(servicesContainer, this.connectionContext.getProps(), partialPluginService, null); + this.pluginManager.init(servicesContainer, this.connectionInfo.getProps(), partialPluginService, null); } @Override @Deprecated public Connection open(HostSpec hostSpec, Properties props) throws SQLException { - return this.pluginManager.forceConnect(this.connectionContext, hostSpec, true, null); + return this.pluginManager.forceConnect(this.connectionInfo, hostSpec, true, null); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java index 975fa3f8d..019a742a2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java @@ -20,7 +20,7 @@ import java.util.Set; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -59,7 +59,7 @@ void registerMonitorTypeIfAbsent( * @param telemetryFactory the telemetry factory for creating telemetry data. * @param defaultConnectionProvider the connection provider to use to create new connections if the monitor * requires it. - * @param connectionContext the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param initializer an initializer function to use to create the monitor if it does not already exist. * @param the type of the monitor. * @return the new or existing monitor. @@ -71,7 +71,7 @@ T runIfAbsent( StorageService storageService, TelemetryFactory telemetryFactory, ConnectionProvider defaultConnectionProvider, - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, MonitorInitializer initializer) throws SQLException; /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index 018f66120..d00ab68d1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -23,7 +23,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledExecutorService; @@ -41,7 +40,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.ServiceUtility; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.events.DataAccessEvent; import software.amazon.jdbc.util.events.Event; import software.amazon.jdbc.util.events.EventPublisher; @@ -180,7 +179,7 @@ public T runIfAbsent( StorageService storageService, TelemetryFactory telemetryFactory, ConnectionProvider defaultConnectionProvider, - ConnectionContext connectionContext, + ConnectionInfo connectionInfo, MonitorInitializer initializer) throws SQLException { CacheContainer cacheContainer = monitorCaches.get(monitorClass); if (cacheContainer == null) { @@ -202,7 +201,7 @@ public T runIfAbsent( storageService, defaultConnectionProvider, telemetryFactory, - connectionContext); + connectionInfo); final MonitorItem monitorItemInner = new MonitorItem(() -> initializer.createMonitor(servicesContainer)); monitorItemInner.getMonitor().start(); return monitorItemInner; @@ -229,13 +228,13 @@ protected FullServicesContainer getNewServicesContainer( StorageService storageService, ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, - ConnectionContext connectionContext) throws SQLException { + ConnectionInfo connectionInfo) throws SQLException { return ServiceUtility.getInstance().createServiceContainer( storageService, this, connectionProvider, telemetryFactory, - connectionContext + connectionInfo ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index e01fa6ea6..a938728f8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -57,7 +57,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -66,7 +66,7 @@ public class ConnectionWrapper implements Connection, CanReleaseResources { private static final Logger LOGGER = Logger.getLogger(ConnectionWrapper.class.getName()); - protected ConnectionContext connectionContext; + protected ConnectionInfo connectionInfo; protected ConnectionPluginManager pluginManager; protected TelemetryFactory telemetryFactory; protected PluginService pluginService; @@ -91,7 +91,7 @@ public ConnectionWrapper( throw new IllegalArgumentException("url"); } - this.connectionContext = new ConnectionContext(url, driverDialect, props); + this.connectionInfo = new ConnectionInfo(url, driverDialect, props); this.configurationProfile = configurationProfile; final ConnectionPluginManager pluginManager = @@ -103,7 +103,7 @@ public ConnectionWrapper( servicesContainer.setConnectionPluginManager(pluginManager); final PluginServiceImpl pluginService = - new PluginServiceImpl(servicesContainer, this.connectionContext, this.configurationProfile); + new PluginServiceImpl(servicesContainer, this.connectionInfo, this.configurationProfile); servicesContainer.setHostListProviderService(pluginService); servicesContainer.setPluginService(pluginService); servicesContainer.setPluginManagerService(pluginService); @@ -159,15 +159,15 @@ protected void init(final Properties props, final FullServicesContainer services final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); if (supplier != null) { - final HostListProvider provider = supplier.getProvider(this.connectionContext, servicesContainer); + final HostListProvider provider = supplier.getProvider(this.connectionInfo, servicesContainer); hostListProviderService.setHostListProvider(provider); } - this.pluginManager.initHostProvider(this.connectionContext, this.hostListProviderService); + this.pluginManager.initHostProvider(this.connectionInfo, this.hostListProviderService); this.pluginService.refreshHostList(); if (this.pluginService.getCurrentConnection() == null) { final Connection conn = this.pluginManager.connect( - this.connectionContext, this.pluginService.getInitialConnectionHostSpec(), true, null); + this.connectionInfo, this.pluginService.getInitialConnectionHostSpec(), true, null); if (conn == null) { throw new SQLException(Messages.get("ConnectionWrapper.connectionNotOpen"), SqlState.UNKNOWN_STATE.getState()); } diff --git a/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java b/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java index 1d597570a..7d96686eb 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java +++ b/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java @@ -20,14 +20,14 @@ import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.PluginServiceImpl; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class TestPluginServiceImpl extends PluginServiceImpl { public TestPluginServiceImpl( - @NonNull FullServicesContainer servicesContainer, @NonNull ConnectionContext connectionContext) + @NonNull FullServicesContainer servicesContainer, @NonNull ConnectionInfo connectionInfo) throws SQLException { - super(servicesContainer, connectionContext); + super(servicesContainer, connectionInfo); } public static void clearHostAvailabilityCache() { From d81dfe348b1a6ab300e12b558faf327315451449 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 22 Sep 2025 14:47:14 -0700 Subject: [PATCH 48/54] fix: store connection string instead of protocol in knownEndpointDialects --- .../src/main/java/software/amazon/jdbc/PluginServiceImpl.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 0e5c4d267..3a1779b2a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -697,7 +697,7 @@ public TargetDriverDialect getTargetDriverDialect() { public void updateDialect(final @NonNull Connection connection) throws SQLException { final Dialect originalDialect = this.connectionInfo.getDbDialect(); Dialect currentDialect = this.dialectProvider.getDialect( - this.connectionInfo.getProtocol(), this.initialConnectionHostSpec, connection); + this.connectionInfo.getInitialConnectionString(), this.initialConnectionHostSpec, connection); if (originalDialect == currentDialect) { return; } From bb1ef95a11cb12ad07761706ca87a3f6b0c8c277 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 24 Sep 2025 14:20:36 -0700 Subject: [PATCH 49/54] Uncomment tests --- .../ConnectionPluginManagerBenchmarks.java | 556 ++--- .../jdbc/benchmarks/PluginBenchmarks.java | 694 +++--- .../aurora/TestAuroraHostListProvider.java | 66 +- .../tests/AdvancedPerformanceTest.java | 1534 ++++++------- .../jdbc/ConnectionPluginManagerTests.java | 1940 ++++++++--------- .../amazon/jdbc/DialectDetectionTests.java | 578 ++--- .../HikariPooledConnectionProviderTest.java | 494 ++--- .../amazon/jdbc/PluginServiceImplTests.java | 1892 ++++++++-------- .../RdsHostListProviderTest.java | 1258 +++++------ .../RdsMultiAzDbClusterListProviderTest.java | 940 ++++---- .../amazon/jdbc/mock/TestPluginOne.java | 310 +-- .../amazon/jdbc/mock/TestPluginThree.java | 174 +- .../jdbc/mock/TestPluginThrowException.java | 224 +- .../amazon/jdbc/mock/TestPluginTwo.java | 66 +- .../AuroraConnectionTrackerPluginTest.java | 490 ++--- ...AwsSecretsManagerConnectionPluginTest.java | 1032 ++++----- .../plugin/DefaultConnectionPluginTest.java | 276 +-- .../CustomEndpointPluginTest.java | 318 +-- .../dev/DeveloperConnectionPluginTest.java | 712 +++--- .../HostMonitoringConnectionPluginTest.java | 688 +++--- .../FederatedAuthPluginTest.java | 450 ++-- .../federatedauth/OktaAuthPluginTest.java | 440 ++-- .../iam/IamAuthConnectionPluginTest.java | 590 ++--- .../LimitlessConnectionPluginTest.java | 324 +-- .../ReadWriteSplittingPluginTest.java | 1252 +++++------ .../monitoring/MonitorServiceImplTest.java | 630 +++--- 26 files changed, 8964 insertions(+), 8964 deletions(-) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java index 773bd83ce..8f14ad5b3 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java @@ -1,278 +1,278 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.benchmarks; -// -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.when; -// import static org.mockito.MockitoAnnotations.openMocks; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.ArrayList; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.List; -// import java.util.Properties; -// import java.util.concurrent.TimeUnit; -// import org.mockito.Mock; -// import org.openjdk.jmh.annotations.Benchmark; -// import org.openjdk.jmh.annotations.BenchmarkMode; -// import org.openjdk.jmh.annotations.Fork; -// import org.openjdk.jmh.annotations.Level; -// import org.openjdk.jmh.annotations.Measurement; -// import org.openjdk.jmh.annotations.Mode; -// import org.openjdk.jmh.annotations.OutputTimeUnit; -// import org.openjdk.jmh.annotations.Scope; -// import org.openjdk.jmh.annotations.Setup; -// import org.openjdk.jmh.annotations.State; -// import org.openjdk.jmh.annotations.TearDown; -// import org.openjdk.jmh.annotations.Warmup; -// import org.openjdk.jmh.profile.GCProfiler; -// import org.openjdk.jmh.runner.Runner; -// import org.openjdk.jmh.runner.RunnerException; -// import org.openjdk.jmh.runner.options.Options; -// import org.openjdk.jmh.runner.options.OptionsBuilder; -// import software.amazon.jdbc.ConnectionPluginFactory; -// import software.amazon.jdbc.ConnectionPluginManager; -// import software.amazon.jdbc.ConnectionProvider; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcMethod; -// import software.amazon.jdbc.NodeChangeOptions; -// import software.amazon.jdbc.OldConnectionSuggestedAction; -// import software.amazon.jdbc.PluginManagerService; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.benchmarks.testplugin.BenchmarkPluginFactory; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.profile.ConfigurationProfile; -// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.connection.ConnectionContext; -// import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; -// import software.amazon.jdbc.util.telemetry.GaugeCallable; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryGauge; -// import software.amazon.jdbc.wrapper.ConnectionWrapper; -// -// @State(Scope.Benchmark) -// @Fork(3) -// @Warmup(iterations = 3) -// @Measurement(iterations = 10) -// @BenchmarkMode(Mode.SingleShotTime) -// @OutputTimeUnit(TimeUnit.NANOSECONDS) -// public class ConnectionPluginManagerBenchmarks { -// -// private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; -// private static final String FIELD_SERVER_ID = "SERVER_ID"; -// private static final String FIELD_SESSION_ID = "SESSION_ID"; -// private static final String url = "protocol//url"; -// private ConnectionContext pluginsContext; -// private ConnectionContext noPluginsContext; -// private ConnectionPluginManager pluginManager; -// private ConnectionPluginManager pluginManagerWithNoPlugins; -// -// @Mock ConnectionProvider mockConnectionProvider; -// @Mock ConnectionWrapper mockConnectionWrapper; -// @Mock FullServicesContainer mockServicesContainer; -// @Mock PluginService mockPluginService; -// @Mock PluginManagerService mockPluginManagerService; -// @Mock TargetDriverDialect mockDriverDialect; -// @Mock TelemetryFactory mockTelemetryFactory; -// @Mock HostListProviderService mockHostListProvider; -// @Mock Connection mockConnection; -// @Mock Statement mockStatement; -// @Mock ResultSet mockResultSet; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryGauge mockTelemetryGauge; -// ConfigurationProfile configurationProfile; -// private AutoCloseable closeable; -// -// public static void main(String[] args) throws RunnerException { -// Options opt = new OptionsBuilder() -// .include(software.amazon.jdbc.benchmarks.PluginBenchmarks.class.getSimpleName()) -// .addProfiler(GCProfiler.class) -// .detectJvmArgs() -// .build(); -// -// new Runner(opt).run(); -// } -// -// @Setup(Level.Iteration) -// public void setUpIteration() throws Exception { -// closeable = openMocks(this); -// -// when(mockConnectionProvider.connect(any(), any(HostSpec.class))).thenReturn(mockConnection); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); -// when(mockConnection.createStatement()).thenReturn(mockStatement); -// when(mockConnection.createStatement()).thenReturn(mockStatement); -// when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); -// when(mockResultSet.next()).thenReturn(true, true, false); -// when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); -// when(mockResultSet.getString(eq(FIELD_SERVER_ID))) -// .thenReturn("myInstance1.domain.com", "myInstance2.domain.com", "myInstance3.domain.com"); -// when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// -// // Create a plugin chain with 10 custom test plugins. -// final List> pluginFactories = new ArrayList<>( -// Collections.nCopies(10, BenchmarkPluginFactory.class)); -// -// configurationProfile = ConfigurationProfileBuilder.get() -// .withName("benchmark") -// .withPluginFactories(pluginFactories) -// .build(); -// -// Properties noPluginsProps = new Properties(); -// noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); -// this.noPluginsContext = new ConnectionContext(url, mockDriverDialect, noPluginsProps); -// -// Properties pluginsProps = new Properties(); -// pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); -// pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); -// this.pluginsContext = new ConnectionContext(url, mockDriverDialect, pluginsProps); -// -// TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); -// -// pluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, -// mockConnectionWrapper, -// telemetryFactory); -// pluginManager.init(mockServicesContainer, pluginsProps, mockPluginManagerService, configurationProfile); -// -// pluginManagerWithNoPlugins = new ConnectionPluginManager(mockConnectionProvider, null, -// mockConnectionWrapper, telemetryFactory); -// pluginManagerWithNoPlugins.init(mockServicesContainer, noPluginsProps, mockPluginManagerService, null); -// } -// -// @TearDown(Level.Iteration) -// public void tearDownIteration() throws Exception { -// closeable.close(); -// } -// -// @Benchmark -// public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws SQLException { -// final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, -// mockConnectionWrapper, mockTelemetryFactory); -// manager.init(mockServicesContainer, this.noPluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); -// return manager; -// } -// -// @Benchmark -// public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws SQLException { -// final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, -// mockConnectionWrapper, mockTelemetryFactory); -// manager.init(mockServicesContainer, this.pluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); -// return manager; -// } -// -// @Benchmark -// public Connection connectWithPlugins() throws SQLException { -// return pluginManager.connect( -// "driverProtocol", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), -// this.pluginsContext.getPropsCopy(), -// true, -// null); -// } -// -// @Benchmark -// public Connection connectWithNoPlugins() throws SQLException { -// return pluginManagerWithNoPlugins.connect( -// "driverProtocol", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), -// this.noPluginsContext.getPropsCopy(), -// true, -// null); -// } -// -// @Benchmark -// public Integer executeWithPlugins() { -// return pluginManager.execute( -// int.class, -// RuntimeException.class, -// mockStatement, -// JdbcMethod.STATEMENT_EXECUTE, -// () -> 1, -// new Object[] {1} -// ); -// } -// -// @Benchmark -// public Integer executeWithNoPlugins() { -// return pluginManagerWithNoPlugins.execute( -// int.class, -// RuntimeException.class, -// mockStatement, -// JdbcMethod.STATEMENT_EXECUTE, -// () -> 1, -// new Object[] {1} -// ); -// } -// -// @Benchmark -// public ConnectionPluginManager initHostProvidersWithPlugins() throws SQLException { -// pluginManager.initHostProvider(this.pluginsContext, mockHostListProvider); -// return pluginManager; -// } -// -// @Benchmark -// public ConnectionPluginManager initHostProvidersWithNoPlugins() throws SQLException { -// pluginManagerWithNoPlugins.initHostProvider(this.noPluginsContext, mockHostListProvider); -// return pluginManager; -// } -// -// @Benchmark -// public EnumSet notifyConnectionChangedWithPlugins() { -// return pluginManager.notifyConnectionChanged( -// EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), -// null); -// } -// -// @Benchmark -// public EnumSet notifyConnectionChangedWithNoPlugins() { -// return pluginManagerWithNoPlugins.notifyConnectionChanged( -// EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), -// null); -// } -// -// @Benchmark -// public ConnectionPluginManager releaseResourcesWithPlugins() { -// pluginManager.releaseResources(); -// return pluginManager; -// } -// -// @Benchmark -// public ConnectionPluginManager releaseResourcesWithNoPlugins() { -// pluginManagerWithNoPlugins.releaseResources(); -// return pluginManager; -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.benchmarks; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.mockito.Mock; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.profile.GCProfiler; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.NodeChangeOptions; +import software.amazon.jdbc.OldConnectionSuggestedAction; +import software.amazon.jdbc.PluginManagerService; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.benchmarks.testplugin.BenchmarkPluginFactory; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.profile.ConfigurationProfile; +import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; +import software.amazon.jdbc.util.telemetry.GaugeCallable; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; +import software.amazon.jdbc.wrapper.ConnectionWrapper; + +@State(Scope.Benchmark) +@Fork(3) +@Warmup(iterations = 3) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ConnectionPluginManagerBenchmarks { + + private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; + private static final String FIELD_SERVER_ID = "SERVER_ID"; + private static final String FIELD_SESSION_ID = "SESSION_ID"; + private static final String url = "protocol//url"; + private ConnectionContext pluginsContext; + private ConnectionContext noPluginsContext; + private ConnectionPluginManager pluginManager; + private ConnectionPluginManager pluginManagerWithNoPlugins; + + @Mock ConnectionProvider mockConnectionProvider; + @Mock ConnectionWrapper mockConnectionWrapper; + @Mock FullServicesContainer mockServicesContainer; + @Mock PluginService mockPluginService; + @Mock PluginManagerService mockPluginManagerService; + @Mock TargetDriverDialect mockDriverDialect; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock HostListProviderService mockHostListProvider; + @Mock Connection mockConnection; + @Mock Statement mockStatement; + @Mock ResultSet mockResultSet; + @Mock TelemetryContext mockTelemetryContext; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryGauge mockTelemetryGauge; + ConfigurationProfile configurationProfile; + private AutoCloseable closeable; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(software.amazon.jdbc.benchmarks.PluginBenchmarks.class.getSimpleName()) + .addProfiler(GCProfiler.class) + .detectJvmArgs() + .build(); + + new Runner(opt).run(); + } + + @Setup(Level.Iteration) + public void setUpIteration() throws Exception { + closeable = openMocks(this); + + when(mockConnectionProvider.connect(any(), any(HostSpec.class))).thenReturn(mockConnection); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true, false); + when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); + when(mockResultSet.getString(eq(FIELD_SERVER_ID))) + .thenReturn("myInstance1.domain.com", "myInstance2.domain.com", "myInstance3.domain.com"); + when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + + // Create a plugin chain with 10 custom test plugins. + final List> pluginFactories = new ArrayList<>( + Collections.nCopies(10, BenchmarkPluginFactory.class)); + + configurationProfile = ConfigurationProfileBuilder.get() + .withName("benchmark") + .withPluginFactories(pluginFactories) + .build(); + + Properties noPluginsProps = new Properties(); + noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); + this.noPluginsContext = new ConnectionContext(url, mockDriverDialect, noPluginsProps); + + Properties pluginsProps = new Properties(); + pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); + pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); + this.pluginsContext = new ConnectionContext(url, mockDriverDialect, pluginsProps); + + TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); + + pluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, + mockConnectionWrapper, + telemetryFactory); + pluginManager.init(mockServicesContainer, pluginsProps, mockPluginManagerService, configurationProfile); + + pluginManagerWithNoPlugins = new ConnectionPluginManager(mockConnectionProvider, null, + mockConnectionWrapper, telemetryFactory); + pluginManagerWithNoPlugins.init(mockServicesContainer, noPluginsProps, mockPluginManagerService, null); + } + + @TearDown(Level.Iteration) + public void tearDownIteration() throws Exception { + closeable.close(); + } + + @Benchmark + public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws SQLException { + final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, + mockConnectionWrapper, mockTelemetryFactory); + manager.init(mockServicesContainer, this.noPluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); + return manager; + } + + @Benchmark + public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws SQLException { + final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, + mockConnectionWrapper, mockTelemetryFactory); + manager.init(mockServicesContainer, this.pluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); + return manager; + } + + @Benchmark + public Connection connectWithPlugins() throws SQLException { + return pluginManager.connect( + "driverProtocol", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), + this.pluginsContext.getPropsCopy(), + true, + null); + } + + @Benchmark + public Connection connectWithNoPlugins() throws SQLException { + return pluginManagerWithNoPlugins.connect( + "driverProtocol", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), + this.noPluginsContext.getPropsCopy(), + true, + null); + } + + @Benchmark + public Integer executeWithPlugins() { + return pluginManager.execute( + int.class, + RuntimeException.class, + mockStatement, + JdbcMethod.STATEMENT_EXECUTE, + () -> 1, + new Object[] {1} + ); + } + + @Benchmark + public Integer executeWithNoPlugins() { + return pluginManagerWithNoPlugins.execute( + int.class, + RuntimeException.class, + mockStatement, + JdbcMethod.STATEMENT_EXECUTE, + () -> 1, + new Object[] {1} + ); + } + + @Benchmark + public ConnectionPluginManager initHostProvidersWithPlugins() throws SQLException { + pluginManager.initHostProvider(this.pluginsContext, mockHostListProvider); + return pluginManager; + } + + @Benchmark + public ConnectionPluginManager initHostProvidersWithNoPlugins() throws SQLException { + pluginManagerWithNoPlugins.initHostProvider(this.noPluginsContext, mockHostListProvider); + return pluginManager; + } + + @Benchmark + public EnumSet notifyConnectionChangedWithPlugins() { + return pluginManager.notifyConnectionChanged( + EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), + null); + } + + @Benchmark + public EnumSet notifyConnectionChangedWithNoPlugins() { + return pluginManagerWithNoPlugins.notifyConnectionChanged( + EnumSet.of(NodeChangeOptions.INITIAL_CONNECTION), + null); + } + + @Benchmark + public ConnectionPluginManager releaseResourcesWithPlugins() { + pluginManager.releaseResources(); + return pluginManager; + } + + @Benchmark + public ConnectionPluginManager releaseResourcesWithNoPlugins() { + pluginManagerWithNoPlugins.releaseResources(); + return pluginManager; + } +} diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java index 36c31c61f..35932705c 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java @@ -1,347 +1,347 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.benchmarks; -// -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyBoolean; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.when; -// -// import com.zaxxer.hikari.HikariConfig; -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.Properties; -// import java.util.concurrent.TimeUnit; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import org.openjdk.jmh.annotations.Benchmark; -// import org.openjdk.jmh.annotations.BenchmarkMode; -// import org.openjdk.jmh.annotations.Fork; -// import org.openjdk.jmh.annotations.Level; -// import org.openjdk.jmh.annotations.Measurement; -// import org.openjdk.jmh.annotations.Mode; -// import org.openjdk.jmh.annotations.OutputTimeUnit; -// import org.openjdk.jmh.annotations.Scope; -// import org.openjdk.jmh.annotations.Setup; -// import org.openjdk.jmh.annotations.State; -// import org.openjdk.jmh.annotations.TearDown; -// import org.openjdk.jmh.annotations.Warmup; -// import org.openjdk.jmh.profile.GCProfiler; -// import org.openjdk.jmh.runner.Runner; -// import org.openjdk.jmh.runner.RunnerException; -// import org.openjdk.jmh.runner.options.Options; -// import org.openjdk.jmh.runner.options.OptionsBuilder; -// import software.amazon.jdbc.ConnectionPluginManager; -// import software.amazon.jdbc.ConnectionProvider; -// import software.amazon.jdbc.ConnectionProviderManager; -// import software.amazon.jdbc.Driver; -// import software.amazon.jdbc.HikariPooledConnectionProvider; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcMethod; -// import software.amazon.jdbc.PluginManagerService; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.benchmarks.testplugin.TestConnectionWrapper; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.monitoring.MonitorService; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.telemetry.GaugeCallable; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryGauge; -// import software.amazon.jdbc.wrapper.ConnectionWrapper; -// -// @State(Scope.Benchmark) -// @Fork(3) -// @Warmup(iterations = 3) -// @Measurement(iterations = 10) -// @BenchmarkMode(Mode.SingleShotTime) -// @OutputTimeUnit(TimeUnit.NANOSECONDS) -// public class PluginBenchmarks { -// -// private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; -// private static final String FIELD_SERVER_ID = "SERVER_ID"; -// private static final String FIELD_SESSION_ID = "SESSION_ID"; -// private static final String CONNECTION_STRING = "jdbc:postgresql://my.domain.com"; -// private static final String PG_CONNECTION_STRING = -// "jdbc:aws-wrapper:postgresql://instance-0.XYZ.us-east-2.rds.amazonaws.com"; -// private static final String TEST_HOST = "instance-0"; -// private static final int TEST_PORT = 5432; -// private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host(TEST_HOST).port(TEST_PORT).build(); -// -// @Mock private StorageService mockStorageService; -// @Mock private MonitorService mockMonitorService; -// @Mock private PluginService mockPluginService; -// @Mock private TargetDriverDialect mockTargetDriverDialect; -// @Mock private Dialect mockDialect; -// @Mock private ConnectionPluginManager mockConnectionPluginManager; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryGauge mockTelemetryGauge; -// @Mock private HostListProviderService mockHostListProviderService; -// @Mock private PluginManagerService mockPluginManagerService; -// @Mock ConnectionProvider mockConnectionProvider; -// @Mock Connection mockConnection; -// @Mock Statement mockStatement; -// @Mock ResultSet mockResultSet; -// private AutoCloseable closeable; -// -// public static void main(String[] args) throws RunnerException { -// Options opt = new OptionsBuilder() -// .include(PluginBenchmarks.class.getSimpleName()) -// .addProfiler(GCProfiler.class) -// .detectJvmArgs() -// .build(); -// -// new Runner(opt).run(); -// } -// -// @Setup(Level.Iteration) -// public void setUpIteration() throws Exception { -// closeable = MockitoAnnotations.openMocks(this); -// when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any())) -// .thenReturn(mockConnection); -// when(mockConnectionPluginManager.execute( -// any(), any(), any(), eq(JdbcMethod.CONNECTION_CREATESTATEMENT), any(), any())) -// .thenReturn(mockStatement); -// when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// // noinspection unchecked -// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); -// when(mockConnectionProvider.connect( -// anyString(), -// any(Dialect.class), -// any(TargetDriverDialect.class), -// any(HostSpec.class), -// any(Properties.class))).thenReturn(mockConnection); -// when(mockConnection.createStatement()).thenReturn(mockStatement); -// when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); -// when(mockResultSet.next()).thenReturn(true, true, false); -// when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); -// when(mockResultSet.getString(eq(FIELD_SERVER_ID))) -// .thenReturn("instance-0", "instance-1"); -// when(mockResultSet.getStatement()).thenReturn(mockStatement); -// when(mockStatement.getConnection()).thenReturn(mockConnection); -// when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); -// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); -// when(this.mockPluginService.getDialect()).thenReturn(mockDialect); -// } -// -// @TearDown(Level.Iteration) -// public void tearDownIteration() throws Exception { -// closeable.close(); -// } -// -// @Benchmark -// public void initAndReleaseBaseLine() { -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithExecutionTimePlugin() throws SQLException { -// try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING)) { -// wrapper.releaseResources(); -// return wrapper; -// } -// } -// -// private ConnectionWrapper getConnectionWrapper(Properties props, String connString) throws SQLException { -// return new TestConnectionWrapper( -// props, -// connString, -// mockConnectionProvider, -// mockTargetDriverDialect, -// mockConnectionPluginManager, -// mockTelemetryFactory, -// mockPluginService, -// mockHostListProviderService, -// mockPluginManagerService, -// mockStorageService, -// mockMonitorService); -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithAuroraHostListPlugin() throws SQLException { -// try (ConnectionWrapper wrapper = getConnectionWrapper(useAuroraHostListPlugin(), CONNECTION_STRING)) { -// wrapper.releaseResources(); -// return wrapper; -// } -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithExecutionTimeAndAuroraHostListPlugins() throws SQLException { -// try (ConnectionWrapper wrapper = -// getConnectionWrapper(useExecutionTimeAndAuroraHostListPlugins(), CONNECTION_STRING)) { -// wrapper.releaseResources(); -// return wrapper; -// } -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin() throws SQLException { -// try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { -// wrapper.releaseResources(); -// return wrapper; -// } -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin() -// throws SQLException { -// try (ConnectionWrapper wrapper = -// getConnectionWrapper(useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { -// wrapper.releaseResources(); -// return wrapper; -// } -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin_internalConnectionPools() throws SQLException { -// HikariPooledConnectionProvider provider = -// new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); -// Driver.setCustomConnectionProvider(provider); -// try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { -// wrapper.releaseResources(); -// ConnectionProviderManager.releaseResources(); -// Driver.resetCustomConnectionProvider(); -// return wrapper; -// } -// } -// -// @Benchmark -// public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin_internalConnectionPools() -// throws SQLException { -// HikariPooledConnectionProvider provider = -// new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); -// Driver.setCustomConnectionProvider(provider); -// try (ConnectionWrapper wrapper = getConnectionWrapper( -// useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { -// wrapper.releaseResources(); -// ConnectionProviderManager.releaseResources(); -// Driver.resetCustomConnectionProvider(); -// return wrapper; -// } -// } -// -// @Benchmark -// public Statement executeStatementBaseline() throws SQLException { -// try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); -// Statement statement = wrapper.createStatement()) { -// return statement; -// } -// } -// -// @Benchmark -// public ResultSet executeStatementWithExecutionTimePlugin() throws SQLException { -// try ( -// ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); -// Statement statement = wrapper.createStatement(); -// ResultSet resultSet = statement.executeQuery("some sql")) { -// return resultSet; -// } -// } -// -// @Benchmark -// public ResultSet executeStatementWithTelemetryDisabled() throws SQLException { -// try ( -// ConnectionWrapper wrapper = getConnectionWrapper(disabledTelemetry(), CONNECTION_STRING); -// Statement statement = wrapper.createStatement(); -// ResultSet resultSet = statement.executeQuery("some sql")) { -// return resultSet; -// } -// } -// -// @Benchmark -// public ResultSet executeStatementWithTelemetry() throws SQLException { -// try ( -// ConnectionWrapper wrapper = getConnectionWrapper(useTelemetry(), CONNECTION_STRING); -// Statement statement = wrapper.createStatement(); -// ResultSet resultSet = statement.executeQuery("some sql")) { -// return resultSet; -// } -// } -// -// Properties useExecutionTimePlugin() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "executionTime"); -// return properties; -// } -// -// Properties useAuroraHostListPlugin() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "auroraHostList"); -// return properties; -// } -// -// Properties useExecutionTimeAndAuroraHostListPlugins() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "executionTime,auroraHostList"); -// return properties; -// } -// -// Properties useReadWriteSplittingPlugin() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "readWriteSplitting"); -// return properties; -// } -// -// Properties useAuroraHostListAndReadWriteSplittingPlugin() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); -// return properties; -// } -// -// Properties useReadWriteSplittingPluginWithReaderLoadBalancing() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "readWriteSplitting"); -// properties.setProperty("loadBalanceReadOnlyTraffic", "true"); -// return properties; -// } -// -// Properties useAuroraHostListAndReadWriteSplittingPluginWithReaderLoadBalancing() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); -// properties.setProperty("loadBalanceReadOnlyTraffic", "true"); -// return properties; -// } -// -// Properties useTelemetry() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); -// properties.setProperty("enableTelemetry", "true"); -// properties.setProperty("telemetryMetricsBackend", "none"); -// properties.setProperty("telemetryTracesBackend", "none"); -// return properties; -// } -// -// Properties disabledTelemetry() { -// final Properties properties = new Properties(); -// properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); -// properties.setProperty("enableTelemetry", "false"); -// return properties; -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.benchmarks; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import com.zaxxer.hikari.HikariConfig; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.profile.GCProfiler; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.ConnectionProviderManager; +import software.amazon.jdbc.Driver; +import software.amazon.jdbc.HikariPooledConnectionProvider; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.PluginManagerService; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.benchmarks.testplugin.TestConnectionWrapper; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.monitoring.MonitorService; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.telemetry.GaugeCallable; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; +import software.amazon.jdbc.wrapper.ConnectionWrapper; + +@State(Scope.Benchmark) +@Fork(3) +@Warmup(iterations = 3) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class PluginBenchmarks { + + private static final String WRITER_SESSION_ID = "MASTER_SESSION_ID"; + private static final String FIELD_SERVER_ID = "SERVER_ID"; + private static final String FIELD_SESSION_ID = "SESSION_ID"; + private static final String CONNECTION_STRING = "jdbc:postgresql://my.domain.com"; + private static final String PG_CONNECTION_STRING = + "jdbc:aws-wrapper:postgresql://instance-0.XYZ.us-east-2.rds.amazonaws.com"; + private static final String TEST_HOST = "instance-0"; + private static final int TEST_PORT = 5432; + private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(TEST_HOST).port(TEST_PORT).build(); + + @Mock private StorageService mockStorageService; + @Mock private MonitorService mockMonitorService; + @Mock private PluginService mockPluginService; + @Mock private TargetDriverDialect mockTargetDriverDialect; + @Mock private Dialect mockDialect; + @Mock private ConnectionPluginManager mockConnectionPluginManager; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryGauge mockTelemetryGauge; + @Mock private HostListProviderService mockHostListProviderService; + @Mock private PluginManagerService mockPluginManagerService; + @Mock ConnectionProvider mockConnectionProvider; + @Mock Connection mockConnection; + @Mock Statement mockStatement; + @Mock ResultSet mockResultSet; + private AutoCloseable closeable; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(PluginBenchmarks.class.getSimpleName()) + .addProfiler(GCProfiler.class) + .detectJvmArgs() + .build(); + + new Runner(opt).run(); + } + + @Setup(Level.Iteration) + public void setUpIteration() throws Exception { + closeable = MockitoAnnotations.openMocks(this); + when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any())) + .thenReturn(mockConnection); + when(mockConnectionPluginManager.execute( + any(), any(), any(), eq(JdbcMethod.CONNECTION_CREATESTATEMENT), any(), any())) + .thenReturn(mockStatement); + when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + // noinspection unchecked + when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + when(mockConnectionProvider.connect( + anyString(), + any(Dialect.class), + any(TargetDriverDialect.class), + any(HostSpec.class), + any(Properties.class))).thenReturn(mockConnection); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true, false); + when(mockResultSet.getString(eq(FIELD_SESSION_ID))).thenReturn(WRITER_SESSION_ID); + when(mockResultSet.getString(eq(FIELD_SERVER_ID))) + .thenReturn("instance-0", "instance-1"); + when(mockResultSet.getStatement()).thenReturn(mockStatement); + when(mockStatement.getConnection()).thenReturn(mockConnection); + when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); + when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); + when(this.mockPluginService.getDialect()).thenReturn(mockDialect); + } + + @TearDown(Level.Iteration) + public void tearDownIteration() throws Exception { + closeable.close(); + } + + @Benchmark + public void initAndReleaseBaseLine() { + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithExecutionTimePlugin() throws SQLException { + try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING)) { + wrapper.releaseResources(); + return wrapper; + } + } + + private ConnectionWrapper getConnectionWrapper(Properties props, String connString) throws SQLException { + return new TestConnectionWrapper( + props, + connString, + mockConnectionProvider, + mockTargetDriverDialect, + mockConnectionPluginManager, + mockTelemetryFactory, + mockPluginService, + mockHostListProviderService, + mockPluginManagerService, + mockStorageService, + mockMonitorService); + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithAuroraHostListPlugin() throws SQLException { + try (ConnectionWrapper wrapper = getConnectionWrapper(useAuroraHostListPlugin(), CONNECTION_STRING)) { + wrapper.releaseResources(); + return wrapper; + } + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithExecutionTimeAndAuroraHostListPlugins() throws SQLException { + try (ConnectionWrapper wrapper = + getConnectionWrapper(useExecutionTimeAndAuroraHostListPlugins(), CONNECTION_STRING)) { + wrapper.releaseResources(); + return wrapper; + } + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin() throws SQLException { + try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { + wrapper.releaseResources(); + return wrapper; + } + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin() + throws SQLException { + try (ConnectionWrapper wrapper = + getConnectionWrapper(useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { + wrapper.releaseResources(); + return wrapper; + } + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithReadWriteSplittingPlugin_internalConnectionPools() throws SQLException { + HikariPooledConnectionProvider provider = + new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); + Driver.setCustomConnectionProvider(provider); + try (ConnectionWrapper wrapper = getConnectionWrapper(useReadWriteSplittingPlugin(), CONNECTION_STRING)) { + wrapper.releaseResources(); + ConnectionProviderManager.releaseResources(); + Driver.resetCustomConnectionProvider(); + return wrapper; + } + } + + @Benchmark + public ConnectionWrapper initAndReleaseWithAuroraHostListAndReadWriteSplittingPlugin_internalConnectionPools() + throws SQLException { + HikariPooledConnectionProvider provider = + new HikariPooledConnectionProvider((hostSpec, props) -> new HikariConfig()); + Driver.setCustomConnectionProvider(provider); + try (ConnectionWrapper wrapper = getConnectionWrapper( + useAuroraHostListAndReadWriteSplittingPlugin(), PG_CONNECTION_STRING)) { + wrapper.releaseResources(); + ConnectionProviderManager.releaseResources(); + Driver.resetCustomConnectionProvider(); + return wrapper; + } + } + + @Benchmark + public Statement executeStatementBaseline() throws SQLException { + try (ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); + Statement statement = wrapper.createStatement()) { + return statement; + } + } + + @Benchmark + public ResultSet executeStatementWithExecutionTimePlugin() throws SQLException { + try ( + ConnectionWrapper wrapper = getConnectionWrapper(useExecutionTimePlugin(), CONNECTION_STRING); + Statement statement = wrapper.createStatement(); + ResultSet resultSet = statement.executeQuery("some sql")) { + return resultSet; + } + } + + @Benchmark + public ResultSet executeStatementWithTelemetryDisabled() throws SQLException { + try ( + ConnectionWrapper wrapper = getConnectionWrapper(disabledTelemetry(), CONNECTION_STRING); + Statement statement = wrapper.createStatement(); + ResultSet resultSet = statement.executeQuery("some sql")) { + return resultSet; + } + } + + @Benchmark + public ResultSet executeStatementWithTelemetry() throws SQLException { + try ( + ConnectionWrapper wrapper = getConnectionWrapper(useTelemetry(), CONNECTION_STRING); + Statement statement = wrapper.createStatement(); + ResultSet resultSet = statement.executeQuery("some sql")) { + return resultSet; + } + } + + Properties useExecutionTimePlugin() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "executionTime"); + return properties; + } + + Properties useAuroraHostListPlugin() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "auroraHostList"); + return properties; + } + + Properties useExecutionTimeAndAuroraHostListPlugins() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "executionTime,auroraHostList"); + return properties; + } + + Properties useReadWriteSplittingPlugin() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "readWriteSplitting"); + return properties; + } + + Properties useAuroraHostListAndReadWriteSplittingPlugin() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); + return properties; + } + + Properties useReadWriteSplittingPluginWithReaderLoadBalancing() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "readWriteSplitting"); + properties.setProperty("loadBalanceReadOnlyTraffic", "true"); + return properties; + } + + Properties useAuroraHostListAndReadWriteSplittingPluginWithReaderLoadBalancing() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "auroraHostList,readWriteSplitting"); + properties.setProperty("loadBalanceReadOnlyTraffic", "true"); + return properties; + } + + Properties useTelemetry() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); + properties.setProperty("enableTelemetry", "true"); + properties.setProperty("telemetryMetricsBackend", "none"); + properties.setProperty("telemetryTracesBackend", "none"); + return properties; + } + + Properties disabledTelemetry() { + final Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "dataCache,auroraHostList,efm2"); + properties.setProperty("enableTelemetry", "false"); + return properties; + } +} diff --git a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java index f7216d85d..c35f6b0f8 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java +++ b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java @@ -1,33 +1,33 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package integration.container.aurora; -// -// import java.util.Properties; -// import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; -// import software.amazon.jdbc.util.FullServicesContainer; -// -// public class TestAuroraHostListProvider extends AuroraHostListProvider { -// -// public TestAuroraHostListProvider( -// FullServicesContainer servicesContainer, Properties properties, String originalUrl) { -// super(properties, originalUrl, servicesContainer, "", "", ""); -// } -// -// public static void clearCache() { -// AuroraHostListProvider.clearAll(); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package integration.container.aurora; + +import java.util.Properties; +import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; +import software.amazon.jdbc.util.FullServicesContainer; + +public class TestAuroraHostListProvider extends AuroraHostListProvider { + + public TestAuroraHostListProvider( + FullServicesContainer servicesContainer, Properties properties, String originalUrl) { + super(properties, originalUrl, servicesContainer, "", "", ""); + } + + public static void clearCache() { + AuroraHostListProvider.clearAll(); + } +} diff --git a/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java b/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java index 3fbbbb762..ae966c1d1 100644 --- a/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java +++ b/wrapper/src/test/java/integration/container/tests/AdvancedPerformanceTest.java @@ -1,767 +1,767 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package integration.container.tests; -// -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.junit.jupiter.api.Assertions.fail; -// import static software.amazon.jdbc.PropertyDefinition.CONNECT_TIMEOUT; -// import static software.amazon.jdbc.PropertyDefinition.PLUGINS; -// import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_COUNT; -// import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_INTERVAL; -// import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_TIME; -// import static software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin.FAILOVER_TIMEOUT_MS; -// -// import integration.TestEnvironmentFeatures; -// import integration.container.ConnectionStringHelper; -// import integration.container.TestDriverProvider; -// import integration.container.TestEnvironment; -// import integration.container.aurora.TestAuroraHostListProvider; -// import integration.container.aurora.TestPluginServiceImpl; -// import integration.container.condition.DisableOnTestFeature; -// import integration.container.condition.EnableOnTestFeature; -// import integration.util.AuroraTestUtility; -// import java.io.File; -// import java.io.FileOutputStream; -// import java.io.IOException; -// import java.net.InetAddress; -// import java.net.UnknownHostException; -// import java.sql.Connection; -// import java.sql.DriverManager; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.ArrayList; -// import java.util.List; -// import java.util.Properties; -// import java.util.concurrent.ConcurrentLinkedQueue; -// import java.util.concurrent.CountDownLatch; -// import java.util.concurrent.TimeUnit; -// import java.util.concurrent.atomic.AtomicLong; -// import java.util.logging.Logger; -// import java.util.stream.Collectors; -// import java.util.stream.Stream; -// import org.apache.poi.ss.usermodel.Cell; -// import org.apache.poi.ss.usermodel.Row; -// import org.apache.poi.xssf.usermodel.XSSFSheet; -// import org.apache.poi.xssf.usermodel.XSSFWorkbook; -// import org.junit.jupiter.api.MethodOrderer; -// import org.junit.jupiter.api.Order; -// import org.junit.jupiter.api.Tag; -// import org.junit.jupiter.api.TestMethodOrder; -// import org.junit.jupiter.api.TestTemplate; -// import org.junit.jupiter.api.extension.ExtendWith; -// import org.junit.jupiter.params.provider.Arguments; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; -// import software.amazon.jdbc.plugin.efm2.HostMonitorServiceImpl; -// import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; -// import software.amazon.jdbc.util.StringUtils; -// -// @TestMethodOrder(MethodOrderer.MethodName.class) -// @ExtendWith(TestDriverProvider.class) -// @EnableOnTestFeature({ -// TestEnvironmentFeatures.PERFORMANCE, -// TestEnvironmentFeatures.FAILOVER_SUPPORTED -// }) -// @DisableOnTestFeature(TestEnvironmentFeatures.RUN_DB_METRICS_ONLY) -// @Tag("advanced") -// @Order(1) -// public class AdvancedPerformanceTest { -// -// private static final Logger LOGGER = Logger.getLogger(AdvancedPerformanceTest.class.getName()); -// -// private static final String MONITORING_CONNECTION_PREFIX = "monitoring-"; -// -// private static final int REPEAT_TIMES = -// StringUtils.isNullOrEmpty(System.getenv("REPEAT_TIMES")) -// ? 5 -// : Integer.parseInt(System.getenv("REPEAT_TIMES")); -// -// private static final int TIMEOUT_SEC = 5; -// private static final int CONNECT_TIMEOUT_SEC = 5; -// private static final int EFM_FAILOVER_TIMEOUT_MS = 300000; -// private static final int EFM_FAILURE_DETECTION_TIME_MS = 30000; -// private static final int EFM_FAILURE_DETECTION_INTERVAL_MS = 5000; -// private static final int EFM_FAILURE_DETECTION_COUNT = 3; -// private static final String QUERY = "SELECT pg_sleep(600)"; // 600s -> 10min -// -// private static final ConcurrentLinkedQueue perfDataList = new ConcurrentLinkedQueue<>(); -// -// protected static final AuroraTestUtility auroraUtil = AuroraTestUtility.getUtility(); -// -// private static void doWritePerfDataToFile( -// String fileName, ConcurrentLinkedQueue dataList) throws IOException { -// -// if (dataList.isEmpty()) { -// return; -// } -// -// LOGGER.finest(() -> "File name: " + fileName); -// -// List sortedData = -// dataList.stream() -// .sorted( -// (d1, d2) -> -// d1.paramFailoverDelayMillis == d2.paramFailoverDelayMillis -// ? d1.paramDriverName.compareTo(d2.paramDriverName) -// : 0) -// .collect(Collectors.toList()); -// -// try (XSSFWorkbook workbook = new XSSFWorkbook()) { -// -// final XSSFSheet sheet = workbook.createSheet("PerformanceResults"); -// -// for (int rows = 0; rows < dataList.size(); rows++) { -// PerfStat perfStat = sortedData.get(rows); -// Row row; -// -// if (rows == 0) { -// // Header -// row = sheet.createRow(0); -// perfStat.writeHeader(row); -// } -// -// row = sheet.createRow(rows + 1); -// perfStat.writeData(row); -// } -// -// // Write to file -// final File newExcelFile = new File(fileName); -// newExcelFile.createNewFile(); -// try (FileOutputStream fileOut = new FileOutputStream(newExcelFile)) { -// workbook.write(fileOut); -// } -// } -// } -// -// @TestTemplate -// public void test_AdvancedPerformance() throws IOException { -// -// perfDataList.clear(); -// -// try { -// Stream argsStream = generateParams(); -// argsStream.forEach( -// a -> { -// try { -// ensureClusterHealthy(); -// LOGGER.finest("DB cluster is healthy."); -// ensureDnsHealthy(); -// LOGGER.finest("DNS is healthy."); -// -// Object[] args = a.get(); -// int failoverDelayTimeMillis = (int) args[0]; -// int runNumber = (int) args[1]; -// -// LOGGER.finest( -// "Iteration " -// + runNumber -// + "/" -// + REPEAT_TIMES -// + " for " -// + failoverDelayTimeMillis -// + "ms delay"); -// -// doMeasurePerformance(failoverDelayTimeMillis); -// -// } catch (InterruptedException ex) { -// throw new RuntimeException(ex); -// } catch (UnknownHostException e) { -// throw new RuntimeException(e); -// } -// }); -// -// } finally { -// doWritePerfDataToFile( -// String.format( -// "./build/reports/tests/AdvancedPerformanceResults_" -// + "Db_%s_Driver_%s_Instances_%d.xlsx", -// TestEnvironment.getCurrent().getInfo().getRequest().getDatabaseEngine(), -// TestEnvironment.getCurrent().getCurrentDriver(), -// TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances()), -// perfDataList); -// perfDataList.clear(); -// } -// } -// -// private void doMeasurePerformance(int sleepDelayMillis) throws InterruptedException { -// -// final AtomicLong downtimeNano = new AtomicLong(); -// final CountDownLatch startLatch = new CountDownLatch(5); -// final CountDownLatch finishLatch = new CountDownLatch(5); -// -// downtimeNano.set(0); -// -// final Thread failoverThread = -// getThread_Failover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); -// final Thread pgThread = -// getThread_DirectDriver(sleepDelayMillis, downtimeNano, startLatch, finishLatch); -// final Thread wrapperEfmThread = -// getThread_WrapperEfm(sleepDelayMillis, downtimeNano, startLatch, finishLatch); -// final Thread wrapperEfmFailoverThread = -// getThread_WrapperEfmFailover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); -// final Thread dnsThread = getThread_DNS(sleepDelayMillis, downtimeNano, startLatch, finishLatch); -// -// failoverThread.start(); -// pgThread.start(); -// wrapperEfmThread.start(); -// wrapperEfmFailoverThread.start(); -// dnsThread.start(); -// -// LOGGER.finest("All threads started."); -// -// finishLatch.await(5, TimeUnit.MINUTES); // wait for all threads to complete -// -// LOGGER.finest("Test is over."); -// -// assertTrue(downtimeNano.get() > 0); -// -// failoverThread.interrupt(); -// pgThread.interrupt(); -// wrapperEfmThread.interrupt(); -// wrapperEfmFailoverThread.interrupt(); -// dnsThread.interrupt(); -// } -// -// private void ensureDnsHealthy() throws UnknownHostException, InterruptedException { -// LOGGER.finest( -// "Writer is " -// + TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getInstances() -// .get(0) -// .getInstanceId()); -// final String writerIpAddress = -// InetAddress.getByName( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getInstances() -// .get(0) -// .getHost()) -// .getHostAddress(); -// LOGGER.finest("Writer resolves to " + writerIpAddress); -// LOGGER.finest( -// "Cluster Endpoint is " -// + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()); -// String clusterIpAddress = -// InetAddress.getByName( -// TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) -// .getHostAddress(); -// LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); -// -// long startTimeNano = System.nanoTime(); -// while (!clusterIpAddress.equals(writerIpAddress) -// && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { -// Thread.sleep(1000); -// clusterIpAddress = -// InetAddress.getByName( -// TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) -// .getHostAddress(); -// LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); -// } -// -// if (!clusterIpAddress.equals(writerIpAddress)) { -// fail("DNS has stale data"); -// } -// } -// -// private Thread getThread_Failover( -// final int sleepDelayMillis, -// final AtomicLong downtimeNano, -// final CountDownLatch startLatch, -// final CountDownLatch finishLatch) { -// -// return new Thread( -// () -> { -// try { -// Thread.sleep(1000); -// startLatch.countDown(); // notify that this thread is ready for work -// startLatch.await( -// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test -// -// LOGGER.finest("Waiting " + sleepDelayMillis + "ms..."); -// Thread.sleep(sleepDelayMillis); -// LOGGER.finest("Trigger failover..."); -// -// // trigger failover -// failoverCluster(); -// downtimeNano.set(System.nanoTime()); -// LOGGER.finest("Failover is started."); -// -// } catch (InterruptedException interruptedException) { -// // Ignore, stop the thread -// } catch (Exception exception) { -// fail("Failover thread exception: " + exception); -// } finally { -// finishLatch.countDown(); -// LOGGER.finest("Failover thread is completed."); -// } -// }); -// } -// -// private Thread getThread_DirectDriver( -// final int sleepDelayMillis, -// final AtomicLong downtimeNano, -// final CountDownLatch startLatch, -// final CountDownLatch finishLatch) { -// -// return new Thread( -// () -> { -// long failureTimeNano = 0; -// try { -// // DB_CONN_STR_PREFIX -// final Properties props = ConnectionStringHelper.getDefaultProperties(); -// final Connection conn = -// openConnectionWithRetry( -// ConnectionStringHelper.getUrl( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpoint(), -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpointPort(), -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getDefaultDbName()), -// props); -// LOGGER.finest("DirectDriver connection is open."); -// -// Thread.sleep(1000); -// startLatch.countDown(); // notify that this thread is ready for work -// startLatch.await( -// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test -// -// LOGGER.finest("DirectDriver Starting long query..."); -// // Execute long query -// final Statement statement = conn.createStatement(); -// try (final ResultSet result = statement.executeQuery(QUERY)) { -// fail("Sleep query finished, should not be possible with network downed."); -// } catch (SQLException throwable) { // Catching executing query -// LOGGER.finest("DirectDriver thread exception: " + throwable); -// // Calculate and add detection time -// assertTrue(downtimeNano.get() > 0); -// failureTimeNano = System.nanoTime() - downtimeNano.get(); -// } -// -// } catch (InterruptedException interruptedException) { -// // Ignore, stop the thread -// } catch (Exception exception) { -// fail("PG thread exception: " + exception); -// } finally { -// PerfStat data = new PerfStat(); -// data.paramFailoverDelayMillis = sleepDelayMillis; -// data.paramDriverName = -// "DirectDriver - " + TestEnvironment.getCurrent().getCurrentDriver(); -// data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); -// LOGGER.finest("DirectDriver Collected data: " + data); -// perfDataList.add(data); -// LOGGER.finest( -// "DirectDriver Failure detection time is " + data.failureDetectionTimeMillis + "ms"); -// -// finishLatch.countDown(); -// LOGGER.finest("DirectDriver thread is completed."); -// } -// }); -// } -// -// private Thread getThread_WrapperEfm( -// final int sleepDelayMillis, -// final AtomicLong downtimeNano, -// final CountDownLatch startLatch, -// final CountDownLatch finishLatch) { -// -// return new Thread( -// () -> { -// long failureTimeNano = 0; -// try { -// final Properties props = ConnectionStringHelper.getDefaultProperties(); -// -// props.setProperty( -// MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, -// String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); -// props.setProperty( -// MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, -// String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); -// CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); -// -// FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); -// FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_INTERVAL_MS)); -// FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); -// PLUGINS.set(props, "efm"); -// -// final Connection conn = -// openConnectionWithRetry( -// ConnectionStringHelper.getWrapperUrl( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpoint(), -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpointPort(), -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getDefaultDbName()), -// props); -// LOGGER.finest("WrapperEfm connection is open."); -// -// Thread.sleep(1000); -// startLatch.countDown(); // notify that this thread is ready for work -// startLatch.await( -// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test -// -// LOGGER.finest("WrapperEfm Starting long query..."); -// // Execute long query -// final Statement statement = conn.createStatement(); -// try (final ResultSet result = statement.executeQuery(QUERY)) { -// fail("Sleep query finished, should not be possible with network downed."); -// } catch (SQLException throwable) { // Catching executing query -// LOGGER.finest("WrapperEfm thread exception: " + throwable); -// -// // Calculate and add detection time -// assertTrue(downtimeNano.get() > 0); -// failureTimeNano = System.nanoTime() - downtimeNano.get(); -// } -// -// } catch (InterruptedException interruptedException) { -// // Ignore, stop the thread -// } catch (Exception exception) { -// fail("WrapperEfm thread exception: " + exception); -// } finally { -// PerfStat data = new PerfStat(); -// data.paramFailoverDelayMillis = sleepDelayMillis; -// data.paramDriverName = -// String.format( -// "AWS Wrapper (%s, EFM)", TestEnvironment.getCurrent().getCurrentDriver()); -// data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); -// LOGGER.finest("WrapperEfm Collected data: " + data); -// perfDataList.add(data); -// LOGGER.finest( -// "WrapperEfm Failure detection time is " + data.failureDetectionTimeMillis + "ms"); -// -// finishLatch.countDown(); -// LOGGER.finest("WrapperEfm thread is completed."); -// } -// }); -// } -// -// private Thread getThread_WrapperEfmFailover( -// final int sleepDelayMillis, -// final AtomicLong downtimeNano, -// final CountDownLatch startLatch, -// final CountDownLatch finishLatch) { -// -// return new Thread( -// () -> { -// long failureTimeNano = 0; -// try { -// final Properties props = ConnectionStringHelper.getDefaultProperties(); -// -// props.setProperty( -// MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, -// String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); -// props.setProperty( -// MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, -// String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); -// CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); -// -// FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); -// FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); -// FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); -// FAILOVER_TIMEOUT_MS.set(props, Integer.toString(EFM_FAILOVER_TIMEOUT_MS)); -// PLUGINS.set(props, "failover,efm"); -// -// final Connection conn = -// openConnectionWithRetry( -// ConnectionStringHelper.getWrapperUrl( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpoint(), -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpointPort(), -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getDefaultDbName()), -// props); -// LOGGER.finest("WrapperEfmFailover connection is open."); -// -// Thread.sleep(1000); -// startLatch.countDown(); // notify that this thread is ready for work -// startLatch.await( -// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test -// -// LOGGER.finest("WrapperEfmFailover Starting long query..."); -// // Execute long query -// final Statement statement = conn.createStatement(); -// try (final ResultSet result = statement.executeQuery(QUERY)) { -// fail("Sleep query finished, should not be possible with network downed."); -// } catch (SQLException throwable) { -// LOGGER.finest("WrapperEfmFailover thread exception: " + throwable); -// if (throwable instanceof FailoverSuccessSQLException) { -// // Calculate and add detection time -// assertTrue(downtimeNano.get() > 0); -// failureTimeNano = System.nanoTime() - downtimeNano.get(); -// } -// } -// -// } catch (InterruptedException interruptedException) { -// // Ignore, stop the thread -// } catch (Exception exception) { -// fail("WrapperEfmFailover thread exception: " + exception); -// } finally { -// PerfStat data = new PerfStat(); -// data.paramFailoverDelayMillis = sleepDelayMillis; -// data.paramDriverName = -// String.format( -// "AWS Wrapper (%s, EFM, Failover)", -// TestEnvironment.getCurrent().getCurrentDriver()); -// data.reconnectTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); -// LOGGER.finest("WrapperEfmFailover Collected data: " + data); -// perfDataList.add(data); -// LOGGER.finest( -// "WrapperEfmFailover Reconnect time is " + data.reconnectTimeMillis + "ms"); -// -// finishLatch.countDown(); -// LOGGER.finest("WrapperEfmFailover thread is completed."); -// } -// }); -// } -// -// private Thread getThread_DNS( -// final int sleepDelayMillis, -// final AtomicLong downtimeNano, -// final CountDownLatch startLatch, -// final CountDownLatch finishLatch) { -// -// return new Thread( -// () -> { -// long failureTimeNano = 0; -// String currentClusterIpAddress; -// -// try { -// currentClusterIpAddress = -// InetAddress.getByName( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpoint()) -// .getHostAddress(); -// LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); -// -// Thread.sleep(1000); -// startLatch.countDown(); // notify that this thread is ready for work -// startLatch.await( -// 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test -// -// String clusterIpAddress = -// InetAddress.getByName( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpoint()) -// .getHostAddress(); -// -// long startTimeNano = System.nanoTime(); -// while (clusterIpAddress.equals(currentClusterIpAddress) -// && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { -// Thread.sleep(1000); -// clusterIpAddress = -// InetAddress.getByName( -// TestEnvironment.getCurrent() -// .getInfo() -// .getDatabaseInfo() -// .getClusterEndpoint()) -// .getHostAddress(); -// LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); -// } -// -// // DNS data has changed -// if (!clusterIpAddress.equals(currentClusterIpAddress)) { -// assertTrue(downtimeNano.get() > 0); -// failureTimeNano = System.nanoTime() - downtimeNano.get(); -// } -// -// } catch (InterruptedException interruptedException) { -// // Ignore, stop the thread -// } catch (Exception exception) { -// fail("Failover thread exception: " + exception); -// } finally { -// PerfStat data = new PerfStat(); -// data.paramFailoverDelayMillis = sleepDelayMillis; -// data.paramDriverName = "DNS"; -// data.dnsUpdateTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); -// LOGGER.finest("DNS Collected data: " + data); -// perfDataList.add(data); -// LOGGER.finest("DNS Update time is " + data.dnsUpdateTimeMillis + "ms"); -// -// finishLatch.countDown(); -// LOGGER.finest("DNS thread is completed."); -// } -// }); -// } -// -// private Connection openConnectionWithRetry(String url, Properties props) { -// Connection conn = null; -// int connectCount = 0; -// while (conn == null && connectCount < 10) { -// try { -// conn = DriverManager.getConnection(url, props); -// -// } catch (SQLException sqlEx) { -// // ignore, try to connect again -// } -// connectCount++; -// } -// -// if (conn == null) { -// fail("Can't connect to " + url); -// } -// return conn; -// } -// -// private void failoverCluster() throws InterruptedException { -// String clusterId = TestEnvironment.getCurrent().getInfo().getRdsDbName(); -// String randomNode = auroraUtil.getRandomDBClusterReaderInstanceId(clusterId); -// auroraUtil.failoverClusterToTarget(clusterId, randomNode); -// } -// -// private void ensureClusterHealthy() throws InterruptedException { -// -// auroraUtil.waitUntilClusterHasRightState( -// TestEnvironment.getCurrent().getInfo().getRdsDbName()); -// -// // Always get the latest topology info with writer as first -// List latestTopology = new ArrayList<>(); -// -// // Need to ensure that cluster details through API matches topology fetched through SQL -// // Wait up to 5min -// long startTimeNano = System.nanoTime(); -// while ((latestTopology.size() -// != TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances() -// || !auroraUtil.isDBInstanceWriter(latestTopology.get(0))) -// && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { -// -// Thread.sleep(5000); -// -// try { -// latestTopology = auroraUtil.getAuroraInstanceIds(); -// } catch (SQLException ex) { -// latestTopology = new ArrayList<>(); -// } -// } -// assertTrue( -// auroraUtil.isDBInstanceWriter( -// TestEnvironment.getCurrent().getInfo().getRdsDbName(), latestTopology.get(0))); -// String currentWriter = latestTopology.get(0); -// -// // Adjust database info to reflect a current writer and to move corresponding instance to -// // position 0. -// TestEnvironment.getCurrent().getInfo().getDatabaseInfo().moveInstanceFirst(currentWriter); -// TestEnvironment.getCurrent().getInfo().getProxyDatabaseInfo().moveInstanceFirst(currentWriter); -// -// auroraUtil.makeSureInstancesUp(TimeUnit.MINUTES.toSeconds(5)); -// -// TestAuroraHostListProvider.clearCache(); -// TestPluginServiceImpl.clearHostAvailabilityCache(); -// HostMonitorThreadContainer.releaseInstance(); -// HostMonitorServiceImpl.closeAllMonitors(); -// } -// -// private static Stream generateParams() { -// -// ArrayList args = new ArrayList<>(); -// -// for (int i = 1; i <= REPEAT_TIMES; i++) { -// args.add(Arguments.of(10000, i)); -// } -// for (int i = 1; i <= REPEAT_TIMES; i++) { -// args.add(Arguments.of(20000, i)); -// } -// for (int i = 1; i <= REPEAT_TIMES; i++) { -// args.add(Arguments.of(30000, i)); -// } -// for (int i = 1; i <= REPEAT_TIMES; i++) { -// args.add(Arguments.of(40000, i)); -// } -// for (int i = 1; i <= REPEAT_TIMES; i++) { -// args.add(Arguments.of(50000, i)); -// } -// for (int i = 1; i <= REPEAT_TIMES; i++) { -// args.add(Arguments.of(60000, i)); -// } -// -// return Stream.of(args.toArray(new Arguments[0])); -// } -// -// private static class PerfStat { -// -// public String paramDriverName; -// public int paramFailoverDelayMillis; -// public long failureDetectionTimeMillis; -// public long reconnectTimeMillis; -// public long dnsUpdateTimeMillis; -// -// public void writeHeader(Row row) { -// Cell cell = row.createCell(0); -// cell.setCellValue("Driver Configuration"); -// cell = row.createCell(1); -// cell.setCellValue("Failover Delay Millis"); -// cell = row.createCell(2); -// cell.setCellValue("Failure Detection Time Millis"); -// cell = row.createCell(3); -// cell.setCellValue("Reconnect Time Millis"); -// cell = row.createCell(4); -// cell.setCellValue("DNS Update Time Millis"); -// } -// -// public void writeData(Row row) { -// Cell cell = row.createCell(0); -// cell.setCellValue(this.paramDriverName); -// cell = row.createCell(1); -// cell.setCellValue(this.paramFailoverDelayMillis); -// cell = row.createCell(2); -// cell.setCellValue(this.failureDetectionTimeMillis); -// cell = row.createCell(3); -// cell.setCellValue(this.reconnectTimeMillis); -// cell = row.createCell(4); -// cell.setCellValue(this.dnsUpdateTimeMillis); -// } -// -// @Override -// public String toString() { -// return String.format("%s [\nparamDriverName=%s,\nparamFailoverDelayMillis=%d,\n" -// + "failureDetectionTimeMillis=%d,\nreconnectTimeMillis=%d,\ndnsUpdateTimeMillis=%d ]", -// super.toString(), -// this.paramDriverName, -// this.paramFailoverDelayMillis, -// this.failureDetectionTimeMillis, -// this.reconnectTimeMillis, -// this.dnsUpdateTimeMillis); -// } -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package integration.container.tests; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static software.amazon.jdbc.PropertyDefinition.CONNECT_TIMEOUT; +import static software.amazon.jdbc.PropertyDefinition.PLUGINS; +import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_COUNT; +import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_INTERVAL; +import static software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin.FAILURE_DETECTION_TIME; +import static software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin.FAILOVER_TIMEOUT_MS; + +import integration.TestEnvironmentFeatures; +import integration.container.ConnectionStringHelper; +import integration.container.TestDriverProvider; +import integration.container.TestEnvironment; +import integration.container.aurora.TestAuroraHostListProvider; +import integration.container.aurora.TestPluginServiceImpl; +import integration.container.condition.DisableOnTestFeature; +import integration.container.condition.EnableOnTestFeature; +import integration.util.AuroraTestUtility; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.poi.ss.usermodel.Cell; +import org.apache.poi.ss.usermodel.Row; +import org.apache.poi.xssf.usermodel.XSSFSheet; +import org.apache.poi.xssf.usermodel.XSSFWorkbook; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.provider.Arguments; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; +import software.amazon.jdbc.plugin.efm2.HostMonitorServiceImpl; +import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; +import software.amazon.jdbc.util.StringUtils; + +@TestMethodOrder(MethodOrderer.MethodName.class) +@ExtendWith(TestDriverProvider.class) +@EnableOnTestFeature({ + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.FAILOVER_SUPPORTED +}) +@DisableOnTestFeature(TestEnvironmentFeatures.RUN_DB_METRICS_ONLY) +@Tag("advanced") +@Order(1) +public class AdvancedPerformanceTest { + + private static final Logger LOGGER = Logger.getLogger(AdvancedPerformanceTest.class.getName()); + + private static final String MONITORING_CONNECTION_PREFIX = "monitoring-"; + + private static final int REPEAT_TIMES = + StringUtils.isNullOrEmpty(System.getenv("REPEAT_TIMES")) + ? 5 + : Integer.parseInt(System.getenv("REPEAT_TIMES")); + + private static final int TIMEOUT_SEC = 5; + private static final int CONNECT_TIMEOUT_SEC = 5; + private static final int EFM_FAILOVER_TIMEOUT_MS = 300000; + private static final int EFM_FAILURE_DETECTION_TIME_MS = 30000; + private static final int EFM_FAILURE_DETECTION_INTERVAL_MS = 5000; + private static final int EFM_FAILURE_DETECTION_COUNT = 3; + private static final String QUERY = "SELECT pg_sleep(600)"; // 600s -> 10min + + private static final ConcurrentLinkedQueue perfDataList = new ConcurrentLinkedQueue<>(); + + protected static final AuroraTestUtility auroraUtil = AuroraTestUtility.getUtility(); + + private static void doWritePerfDataToFile( + String fileName, ConcurrentLinkedQueue dataList) throws IOException { + + if (dataList.isEmpty()) { + return; + } + + LOGGER.finest(() -> "File name: " + fileName); + + List sortedData = + dataList.stream() + .sorted( + (d1, d2) -> + d1.paramFailoverDelayMillis == d2.paramFailoverDelayMillis + ? d1.paramDriverName.compareTo(d2.paramDriverName) + : 0) + .collect(Collectors.toList()); + + try (XSSFWorkbook workbook = new XSSFWorkbook()) { + + final XSSFSheet sheet = workbook.createSheet("PerformanceResults"); + + for (int rows = 0; rows < dataList.size(); rows++) { + PerfStat perfStat = sortedData.get(rows); + Row row; + + if (rows == 0) { + // Header + row = sheet.createRow(0); + perfStat.writeHeader(row); + } + + row = sheet.createRow(rows + 1); + perfStat.writeData(row); + } + + // Write to file + final File newExcelFile = new File(fileName); + newExcelFile.createNewFile(); + try (FileOutputStream fileOut = new FileOutputStream(newExcelFile)) { + workbook.write(fileOut); + } + } + } + + @TestTemplate + public void test_AdvancedPerformance() throws IOException { + + perfDataList.clear(); + + try { + Stream argsStream = generateParams(); + argsStream.forEach( + a -> { + try { + ensureClusterHealthy(); + LOGGER.finest("DB cluster is healthy."); + ensureDnsHealthy(); + LOGGER.finest("DNS is healthy."); + + Object[] args = a.get(); + int failoverDelayTimeMillis = (int) args[0]; + int runNumber = (int) args[1]; + + LOGGER.finest( + "Iteration " + + runNumber + + "/" + + REPEAT_TIMES + + " for " + + failoverDelayTimeMillis + + "ms delay"); + + doMeasurePerformance(failoverDelayTimeMillis); + + } catch (InterruptedException ex) { + throw new RuntimeException(ex); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + }); + + } finally { + doWritePerfDataToFile( + String.format( + "./build/reports/tests/AdvancedPerformanceResults_" + + "Db_%s_Driver_%s_Instances_%d.xlsx", + TestEnvironment.getCurrent().getInfo().getRequest().getDatabaseEngine(), + TestEnvironment.getCurrent().getCurrentDriver(), + TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances()), + perfDataList); + perfDataList.clear(); + } + } + + private void doMeasurePerformance(int sleepDelayMillis) throws InterruptedException { + + final AtomicLong downtimeNano = new AtomicLong(); + final CountDownLatch startLatch = new CountDownLatch(5); + final CountDownLatch finishLatch = new CountDownLatch(5); + + downtimeNano.set(0); + + final Thread failoverThread = + getThread_Failover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); + final Thread pgThread = + getThread_DirectDriver(sleepDelayMillis, downtimeNano, startLatch, finishLatch); + final Thread wrapperEfmThread = + getThread_WrapperEfm(sleepDelayMillis, downtimeNano, startLatch, finishLatch); + final Thread wrapperEfmFailoverThread = + getThread_WrapperEfmFailover(sleepDelayMillis, downtimeNano, startLatch, finishLatch); + final Thread dnsThread = getThread_DNS(sleepDelayMillis, downtimeNano, startLatch, finishLatch); + + failoverThread.start(); + pgThread.start(); + wrapperEfmThread.start(); + wrapperEfmFailoverThread.start(); + dnsThread.start(); + + LOGGER.finest("All threads started."); + + finishLatch.await(5, TimeUnit.MINUTES); // wait for all threads to complete + + LOGGER.finest("Test is over."); + + assertTrue(downtimeNano.get() > 0); + + failoverThread.interrupt(); + pgThread.interrupt(); + wrapperEfmThread.interrupt(); + wrapperEfmFailoverThread.interrupt(); + dnsThread.interrupt(); + } + + private void ensureDnsHealthy() throws UnknownHostException, InterruptedException { + LOGGER.finest( + "Writer is " + + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getInstances() + .get(0) + .getInstanceId()); + final String writerIpAddress = + InetAddress.getByName( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getInstances() + .get(0) + .getHost()) + .getHostAddress(); + LOGGER.finest("Writer resolves to " + writerIpAddress); + LOGGER.finest( + "Cluster Endpoint is " + + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()); + String clusterIpAddress = + InetAddress.getByName( + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) + .getHostAddress(); + LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); + + long startTimeNano = System.nanoTime(); + while (!clusterIpAddress.equals(writerIpAddress) + && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { + Thread.sleep(1000); + clusterIpAddress = + InetAddress.getByName( + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getClusterEndpoint()) + .getHostAddress(); + LOGGER.finest("Cluster Endpoint resolves to " + clusterIpAddress); + } + + if (!clusterIpAddress.equals(writerIpAddress)) { + fail("DNS has stale data"); + } + } + + private Thread getThread_Failover( + final int sleepDelayMillis, + final AtomicLong downtimeNano, + final CountDownLatch startLatch, + final CountDownLatch finishLatch) { + + return new Thread( + () -> { + try { + Thread.sleep(1000); + startLatch.countDown(); // notify that this thread is ready for work + startLatch.await( + 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test + + LOGGER.finest("Waiting " + sleepDelayMillis + "ms..."); + Thread.sleep(sleepDelayMillis); + LOGGER.finest("Trigger failover..."); + + // trigger failover + failoverCluster(); + downtimeNano.set(System.nanoTime()); + LOGGER.finest("Failover is started."); + + } catch (InterruptedException interruptedException) { + // Ignore, stop the thread + } catch (Exception exception) { + fail("Failover thread exception: " + exception); + } finally { + finishLatch.countDown(); + LOGGER.finest("Failover thread is completed."); + } + }); + } + + private Thread getThread_DirectDriver( + final int sleepDelayMillis, + final AtomicLong downtimeNano, + final CountDownLatch startLatch, + final CountDownLatch finishLatch) { + + return new Thread( + () -> { + long failureTimeNano = 0; + try { + // DB_CONN_STR_PREFIX + final Properties props = ConnectionStringHelper.getDefaultProperties(); + final Connection conn = + openConnectionWithRetry( + ConnectionStringHelper.getUrl( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpoint(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpointPort(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getDefaultDbName()), + props); + LOGGER.finest("DirectDriver connection is open."); + + Thread.sleep(1000); + startLatch.countDown(); // notify that this thread is ready for work + startLatch.await( + 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test + + LOGGER.finest("DirectDriver Starting long query..."); + // Execute long query + final Statement statement = conn.createStatement(); + try (final ResultSet result = statement.executeQuery(QUERY)) { + fail("Sleep query finished, should not be possible with network downed."); + } catch (SQLException throwable) { // Catching executing query + LOGGER.finest("DirectDriver thread exception: " + throwable); + // Calculate and add detection time + assertTrue(downtimeNano.get() > 0); + failureTimeNano = System.nanoTime() - downtimeNano.get(); + } + + } catch (InterruptedException interruptedException) { + // Ignore, stop the thread + } catch (Exception exception) { + fail("PG thread exception: " + exception); + } finally { + PerfStat data = new PerfStat(); + data.paramFailoverDelayMillis = sleepDelayMillis; + data.paramDriverName = + "DirectDriver - " + TestEnvironment.getCurrent().getCurrentDriver(); + data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); + LOGGER.finest("DirectDriver Collected data: " + data); + perfDataList.add(data); + LOGGER.finest( + "DirectDriver Failure detection time is " + data.failureDetectionTimeMillis + "ms"); + + finishLatch.countDown(); + LOGGER.finest("DirectDriver thread is completed."); + } + }); + } + + private Thread getThread_WrapperEfm( + final int sleepDelayMillis, + final AtomicLong downtimeNano, + final CountDownLatch startLatch, + final CountDownLatch finishLatch) { + + return new Thread( + () -> { + long failureTimeNano = 0; + try { + final Properties props = ConnectionStringHelper.getDefaultProperties(); + + props.setProperty( + MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, + String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); + props.setProperty( + MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, + String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); + CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); + + FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); + FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_INTERVAL_MS)); + FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); + PLUGINS.set(props, "efm"); + + final Connection conn = + openConnectionWithRetry( + ConnectionStringHelper.getWrapperUrl( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpoint(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpointPort(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getDefaultDbName()), + props); + LOGGER.finest("WrapperEfm connection is open."); + + Thread.sleep(1000); + startLatch.countDown(); // notify that this thread is ready for work + startLatch.await( + 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test + + LOGGER.finest("WrapperEfm Starting long query..."); + // Execute long query + final Statement statement = conn.createStatement(); + try (final ResultSet result = statement.executeQuery(QUERY)) { + fail("Sleep query finished, should not be possible with network downed."); + } catch (SQLException throwable) { // Catching executing query + LOGGER.finest("WrapperEfm thread exception: " + throwable); + + // Calculate and add detection time + assertTrue(downtimeNano.get() > 0); + failureTimeNano = System.nanoTime() - downtimeNano.get(); + } + + } catch (InterruptedException interruptedException) { + // Ignore, stop the thread + } catch (Exception exception) { + fail("WrapperEfm thread exception: " + exception); + } finally { + PerfStat data = new PerfStat(); + data.paramFailoverDelayMillis = sleepDelayMillis; + data.paramDriverName = + String.format( + "AWS Wrapper (%s, EFM)", TestEnvironment.getCurrent().getCurrentDriver()); + data.failureDetectionTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); + LOGGER.finest("WrapperEfm Collected data: " + data); + perfDataList.add(data); + LOGGER.finest( + "WrapperEfm Failure detection time is " + data.failureDetectionTimeMillis + "ms"); + + finishLatch.countDown(); + LOGGER.finest("WrapperEfm thread is completed."); + } + }); + } + + private Thread getThread_WrapperEfmFailover( + final int sleepDelayMillis, + final AtomicLong downtimeNano, + final CountDownLatch startLatch, + final CountDownLatch finishLatch) { + + return new Thread( + () -> { + long failureTimeNano = 0; + try { + final Properties props = ConnectionStringHelper.getDefaultProperties(); + + props.setProperty( + MONITORING_CONNECTION_PREFIX + PropertyDefinition.CONNECT_TIMEOUT.name, + String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); + props.setProperty( + MONITORING_CONNECTION_PREFIX + PropertyDefinition.SOCKET_TIMEOUT.name, + String.valueOf(TimeUnit.SECONDS.toMillis(TIMEOUT_SEC))); + CONNECT_TIMEOUT.set(props, String.valueOf(TimeUnit.SECONDS.toMillis(CONNECT_TIMEOUT_SEC))); + + FAILURE_DETECTION_TIME.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); + FAILURE_DETECTION_INTERVAL.set(props, Integer.toString(EFM_FAILURE_DETECTION_TIME_MS)); + FAILURE_DETECTION_COUNT.set(props, Integer.toString(EFM_FAILURE_DETECTION_COUNT)); + FAILOVER_TIMEOUT_MS.set(props, Integer.toString(EFM_FAILOVER_TIMEOUT_MS)); + PLUGINS.set(props, "failover,efm"); + + final Connection conn = + openConnectionWithRetry( + ConnectionStringHelper.getWrapperUrl( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpoint(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpointPort(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getDefaultDbName()), + props); + LOGGER.finest("WrapperEfmFailover connection is open."); + + Thread.sleep(1000); + startLatch.countDown(); // notify that this thread is ready for work + startLatch.await( + 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test + + LOGGER.finest("WrapperEfmFailover Starting long query..."); + // Execute long query + final Statement statement = conn.createStatement(); + try (final ResultSet result = statement.executeQuery(QUERY)) { + fail("Sleep query finished, should not be possible with network downed."); + } catch (SQLException throwable) { + LOGGER.finest("WrapperEfmFailover thread exception: " + throwable); + if (throwable instanceof FailoverSuccessSQLException) { + // Calculate and add detection time + assertTrue(downtimeNano.get() > 0); + failureTimeNano = System.nanoTime() - downtimeNano.get(); + } + } + + } catch (InterruptedException interruptedException) { + // Ignore, stop the thread + } catch (Exception exception) { + fail("WrapperEfmFailover thread exception: " + exception); + } finally { + PerfStat data = new PerfStat(); + data.paramFailoverDelayMillis = sleepDelayMillis; + data.paramDriverName = + String.format( + "AWS Wrapper (%s, EFM, Failover)", + TestEnvironment.getCurrent().getCurrentDriver()); + data.reconnectTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); + LOGGER.finest("WrapperEfmFailover Collected data: " + data); + perfDataList.add(data); + LOGGER.finest( + "WrapperEfmFailover Reconnect time is " + data.reconnectTimeMillis + "ms"); + + finishLatch.countDown(); + LOGGER.finest("WrapperEfmFailover thread is completed."); + } + }); + } + + private Thread getThread_DNS( + final int sleepDelayMillis, + final AtomicLong downtimeNano, + final CountDownLatch startLatch, + final CountDownLatch finishLatch) { + + return new Thread( + () -> { + long failureTimeNano = 0; + String currentClusterIpAddress; + + try { + currentClusterIpAddress = + InetAddress.getByName( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpoint()) + .getHostAddress(); + LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); + + Thread.sleep(1000); + startLatch.countDown(); // notify that this thread is ready for work + startLatch.await( + 5, TimeUnit.MINUTES); // wait for another threads to be ready to start the test + + String clusterIpAddress = + InetAddress.getByName( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpoint()) + .getHostAddress(); + + long startTimeNano = System.nanoTime(); + while (clusterIpAddress.equals(currentClusterIpAddress) + && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { + Thread.sleep(1000); + clusterIpAddress = + InetAddress.getByName( + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getClusterEndpoint()) + .getHostAddress(); + LOGGER.finest("Cluster Endpoint resolves to " + currentClusterIpAddress); + } + + // DNS data has changed + if (!clusterIpAddress.equals(currentClusterIpAddress)) { + assertTrue(downtimeNano.get() > 0); + failureTimeNano = System.nanoTime() - downtimeNano.get(); + } + + } catch (InterruptedException interruptedException) { + // Ignore, stop the thread + } catch (Exception exception) { + fail("Failover thread exception: " + exception); + } finally { + PerfStat data = new PerfStat(); + data.paramFailoverDelayMillis = sleepDelayMillis; + data.paramDriverName = "DNS"; + data.dnsUpdateTimeMillis = TimeUnit.NANOSECONDS.toMillis(failureTimeNano); + LOGGER.finest("DNS Collected data: " + data); + perfDataList.add(data); + LOGGER.finest("DNS Update time is " + data.dnsUpdateTimeMillis + "ms"); + + finishLatch.countDown(); + LOGGER.finest("DNS thread is completed."); + } + }); + } + + private Connection openConnectionWithRetry(String url, Properties props) { + Connection conn = null; + int connectCount = 0; + while (conn == null && connectCount < 10) { + try { + conn = DriverManager.getConnection(url, props); + + } catch (SQLException sqlEx) { + // ignore, try to connect again + } + connectCount++; + } + + if (conn == null) { + fail("Can't connect to " + url); + } + return conn; + } + + private void failoverCluster() throws InterruptedException { + String clusterId = TestEnvironment.getCurrent().getInfo().getRdsDbName(); + String randomNode = auroraUtil.getRandomDBClusterReaderInstanceId(clusterId); + auroraUtil.failoverClusterToTarget(clusterId, randomNode); + } + + private void ensureClusterHealthy() throws InterruptedException { + + auroraUtil.waitUntilClusterHasRightState( + TestEnvironment.getCurrent().getInfo().getRdsDbName()); + + // Always get the latest topology info with writer as first + List latestTopology = new ArrayList<>(); + + // Need to ensure that cluster details through API matches topology fetched through SQL + // Wait up to 5min + long startTimeNano = System.nanoTime(); + while ((latestTopology.size() + != TestEnvironment.getCurrent().getInfo().getRequest().getNumOfInstances() + || !auroraUtil.isDBInstanceWriter(latestTopology.get(0))) + && TimeUnit.NANOSECONDS.toMinutes(System.nanoTime() - startTimeNano) < 5) { + + Thread.sleep(5000); + + try { + latestTopology = auroraUtil.getAuroraInstanceIds(); + } catch (SQLException ex) { + latestTopology = new ArrayList<>(); + } + } + assertTrue( + auroraUtil.isDBInstanceWriter( + TestEnvironment.getCurrent().getInfo().getRdsDbName(), latestTopology.get(0))); + String currentWriter = latestTopology.get(0); + + // Adjust database info to reflect a current writer and to move corresponding instance to + // position 0. + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().moveInstanceFirst(currentWriter); + TestEnvironment.getCurrent().getInfo().getProxyDatabaseInfo().moveInstanceFirst(currentWriter); + + auroraUtil.makeSureInstancesUp(TimeUnit.MINUTES.toSeconds(5)); + + TestAuroraHostListProvider.clearCache(); + TestPluginServiceImpl.clearHostAvailabilityCache(); + HostMonitorThreadContainer.releaseInstance(); + HostMonitorServiceImpl.closeAllMonitors(); + } + + private static Stream generateParams() { + + ArrayList args = new ArrayList<>(); + + for (int i = 1; i <= REPEAT_TIMES; i++) { + args.add(Arguments.of(10000, i)); + } + for (int i = 1; i <= REPEAT_TIMES; i++) { + args.add(Arguments.of(20000, i)); + } + for (int i = 1; i <= REPEAT_TIMES; i++) { + args.add(Arguments.of(30000, i)); + } + for (int i = 1; i <= REPEAT_TIMES; i++) { + args.add(Arguments.of(40000, i)); + } + for (int i = 1; i <= REPEAT_TIMES; i++) { + args.add(Arguments.of(50000, i)); + } + for (int i = 1; i <= REPEAT_TIMES; i++) { + args.add(Arguments.of(60000, i)); + } + + return Stream.of(args.toArray(new Arguments[0])); + } + + private static class PerfStat { + + public String paramDriverName; + public int paramFailoverDelayMillis; + public long failureDetectionTimeMillis; + public long reconnectTimeMillis; + public long dnsUpdateTimeMillis; + + public void writeHeader(Row row) { + Cell cell = row.createCell(0); + cell.setCellValue("Driver Configuration"); + cell = row.createCell(1); + cell.setCellValue("Failover Delay Millis"); + cell = row.createCell(2); + cell.setCellValue("Failure Detection Time Millis"); + cell = row.createCell(3); + cell.setCellValue("Reconnect Time Millis"); + cell = row.createCell(4); + cell.setCellValue("DNS Update Time Millis"); + } + + public void writeData(Row row) { + Cell cell = row.createCell(0); + cell.setCellValue(this.paramDriverName); + cell = row.createCell(1); + cell.setCellValue(this.paramFailoverDelayMillis); + cell = row.createCell(2); + cell.setCellValue(this.failureDetectionTimeMillis); + cell = row.createCell(3); + cell.setCellValue(this.reconnectTimeMillis); + cell = row.createCell(4); + cell.setCellValue(this.dnsUpdateTimeMillis); + } + + @Override + public String toString() { + return String.format("%s [\nparamDriverName=%s,\nparamFailoverDelayMillis=%d,\n" + + "failureDetectionTimeMillis=%d,\nreconnectTimeMillis=%d,\ndnsUpdateTimeMillis=%d ]", + super.toString(), + this.paramDriverName, + this.paramFailoverDelayMillis, + this.failureDetectionTimeMillis, + this.reconnectTimeMillis, + this.dnsUpdateTimeMillis); + } + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java index 45ed19fd4..355276ad3 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java @@ -1,970 +1,970 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc; -// -// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.mock; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.HashSet; -// import java.util.List; -// import java.util.Properties; -// import java.util.concurrent.CompletableFuture; -// import java.util.concurrent.CountDownLatch; -// import java.util.concurrent.TimeUnit; -// import java.util.concurrent.atomic.AtomicBoolean; -// import java.util.concurrent.locks.ReentrantLock; -// import java.util.logging.Logger; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.mock.TestPluginOne; -// import software.amazon.jdbc.mock.TestPluginThree; -// import software.amazon.jdbc.mock.TestPluginThrowException; -// import software.amazon.jdbc.mock.TestPluginTwo; -// import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; -// import software.amazon.jdbc.plugin.DefaultConnectionPlugin; -// import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; -// import software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin; -// import software.amazon.jdbc.profile.ConfigurationProfile; -// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.WrapperUtils; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.wrapper.ConnectionWrapper; -// -// public class ConnectionPluginManagerTests { -// -// private static final Logger LOGGER = Logger.getLogger(ConnectionPluginManagerTests.class.getName()); -// -// @Mock JdbcCallable mockSqlFunction; -// @Mock ConnectionProvider mockConnectionProvider; -// @Mock ConnectionWrapper mockConnectionWrapper; -// @Mock TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock FullServicesContainer mockServicesContainer; -// @Mock PluginService mockPluginService; -// @Mock PluginManagerService mockPluginManagerService; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// -// ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); -// -// private AutoCloseable closeable; -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @BeforeEach -// void init() { -// closeable = MockitoAnnotations.openMocks(this); -// when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); -// when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); -// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); -// } -// -// @Test -// public void testExecuteJdbcCallA() throws Exception { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls)); -// -// final Properties testProperties = new Properties(); -// -// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; -// -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Object result = -// target.execute( -// String.class, -// Exception.class, -// Connection.class, -// JdbcMethod.BLOB_LENGTH, -// () -> { -// calls.add("targetCall"); -// return "resulTestValue"; -// }, -// testArgs); -// -// assertEquals("resulTestValue", result); -// -// assertEquals(7, calls.size()); -// assertEquals("TestPluginOne:before", calls.get(0)); -// assertEquals("TestPluginTwo:before", calls.get(1)); -// assertEquals("TestPluginThree:before", calls.get(2)); -// assertEquals("targetCall", calls.get(3)); -// assertEquals("TestPluginThree:after", calls.get(4)); -// assertEquals("TestPluginTwo:after", calls.get(5)); -// assertEquals("TestPluginOne:after", calls.get(6)); -// } -// -// @Test -// public void testExecuteJdbcCallB() throws Exception { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls)); -// -// final Properties testProperties = new Properties(); -// -// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; -// -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Object result = -// target.execute( -// String.class, -// Exception.class, -// Connection.class, -// JdbcMethod.BLOB_POSITION, -// () -> { -// calls.add("targetCall"); -// return "resulTestValue"; -// }, -// testArgs); -// -// assertEquals("resulTestValue", result); -// -// assertEquals(5, calls.size()); -// assertEquals("TestPluginOne:before", calls.get(0)); -// assertEquals("TestPluginTwo:before", calls.get(1)); -// assertEquals("targetCall", calls.get(2)); -// assertEquals("TestPluginTwo:after", calls.get(3)); -// assertEquals("TestPluginOne:after", calls.get(4)); -// } -// -// @Test -// public void testExecuteJdbcCallC() throws Exception { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls)); -// -// final Properties testProperties = new Properties(); -// -// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; -// -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Object result = -// target.execute( -// String.class, -// Exception.class, -// Connection.class, -// JdbcMethod.BLOB_GETBYTES, -// () -> { -// calls.add("targetCall"); -// return "resulTestValue"; -// }, -// testArgs); -// -// assertEquals("resulTestValue", result); -// -// assertEquals(3, calls.size()); -// assertEquals("TestPluginOne:before", calls.get(0)); -// assertEquals("targetCall", calls.get(1)); -// assertEquals("TestPluginOne:after", calls.get(2)); -// } -// -// @Test -// public void testConnect() throws Exception { -// -// final Connection expectedConnection = mock(Connection.class); -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls, expectedConnection)); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Connection conn = target.connect("any", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, -// true, null); -// -// assertEquals(expectedConnection, conn); -// assertEquals(4, calls.size()); -// assertEquals("TestPluginOne:before connect", calls.get(0)); -// assertEquals("TestPluginThree:before connect", calls.get(1)); -// assertEquals("TestPluginThree:connection", calls.get(2)); -// assertEquals("TestPluginOne:after connect", calls.get(3)); -// } -// -// @Test -// public void testConnectWithSkipPlugin() throws Exception { -// -// final Connection expectedConnection = mock(Connection.class); -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// final ConnectionPlugin pluginOne = new TestPluginOne(calls); -// testPlugins.add(pluginOne); -// final ConnectionPlugin pluginTwo = new TestPluginTwo(calls); -// testPlugins.add(pluginTwo); -// final ConnectionPlugin pluginThree = new TestPluginThree(calls, expectedConnection); -// testPlugins.add(pluginThree); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Connection conn = target.connect("any", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, -// true, pluginOne); -// -// assertEquals(expectedConnection, conn); -// assertEquals(2, calls.size()); -// assertEquals("TestPluginThree:before connect", calls.get(0)); -// assertEquals("TestPluginThree:connection", calls.get(1)); -// } -// -// @Test -// public void testForceConnect() throws Exception { -// -// final Connection expectedConnection = mock(Connection.class); -// final ArrayList calls = new ArrayList<>(); -// final ArrayList testPlugins = new ArrayList<>(); -// -// // TestPluginOne is not an AuthenticationConnectionPlugin. -// testPlugins.add(new TestPluginOne(calls)); -// -// // TestPluginTwo is an AuthenticationConnectionPlugin, but it's not subscribed to "forceConnect" method. -// testPlugins.add(new TestPluginTwo(calls)); -// -// // TestPluginThree is an AuthenticationConnectionPlugin, and it's subscribed to "forceConnect" method. -// testPlugins.add(new TestPluginThree(calls, expectedConnection)); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Connection conn = target.forceConnect("any", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, -// true, -// null); -// -// // Expecting only TestPluginThree to participate in forceConnect(). -// assertEquals(expectedConnection, conn); -// assertEquals(4, calls.size()); -// assertEquals("TestPluginOne:before forceConnect", calls.get(0)); -// assertEquals("TestPluginThree:before forceConnect", calls.get(1)); -// assertEquals("TestPluginThree:forced connection", calls.get(2)); -// assertEquals("TestPluginOne:after forceConnect", calls.get(3)); -// } -// -// @Test -// public void testConnectWithSQLExceptionBefore() { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThrowException(calls, SQLException.class, true)); -// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// assertThrows( -// SQLException.class, -// () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), -// testProperties, true, null)); -// -// assertEquals(2, calls.size()); -// assertEquals("TestPluginOne:before connect", calls.get(0)); -// assertEquals("TestPluginThrowException:before", calls.get(1)); -// } -// -// @Test -// public void testConnectWithSQLExceptionAfter() { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThrowException(calls, SQLException.class, false)); -// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// assertThrows( -// SQLException.class, -// () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), -// testProperties, true, null)); -// -// assertEquals(5, calls.size()); -// assertEquals("TestPluginOne:before connect", calls.get(0)); -// assertEquals("TestPluginThrowException:before", calls.get(1)); -// assertEquals("TestPluginThree:before connect", calls.get(2)); -// assertEquals("TestPluginThree:connection", calls.get(3)); -// assertEquals("TestPluginThrowException:after", calls.get(4)); -// } -// -// @Test -// public void testConnectWithUnexpectedExceptionBefore() { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, true)); -// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Exception ex = -// assertThrows( -// IllegalArgumentException.class, -// () -> target.connect("any", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), -// testProperties, true, null)); -// -// assertEquals(2, calls.size()); -// assertEquals("TestPluginOne:before connect", calls.get(0)); -// assertEquals("TestPluginThrowException:before", calls.get(1)); -// } -// -// @Test -// public void testConnectWithUnexpectedExceptionAfter() { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, false)); -// testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); -// -// final Exception ex = -// assertThrows( -// IllegalArgumentException.class, -// () -> target.connect("any", -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), -// testProperties, true, null)); -// -// assertEquals(5, calls.size()); -// assertEquals("TestPluginOne:before connect", calls.get(0)); -// assertEquals("TestPluginThrowException:before", calls.get(1)); -// assertEquals("TestPluginThree:before connect", calls.get(2)); -// assertEquals("TestPluginThree:connection", calls.get(3)); -// assertEquals("TestPluginThrowException:after", calls.get(4)); -// } -// -// @Test -// public void testExecuteCachedJdbcCallA() throws Exception { -// -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls)); -// -// final Properties testProperties = new Properties(); -// -// final Object[] testArgs = new Object[] {10, "arg2", 3.33}; -// -// final ConnectionPluginManager target = Mockito.spy( -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); -// -// Object result = -// target.execute( -// String.class, -// Exception.class, -// Connection.class, -// JdbcMethod.BLOB_LENGTH, -// () -> { -// calls.add("targetCall"); -// return "resulTestValue"; -// }, -// testArgs); -// -// assertEquals("resulTestValue", result); -// -// // The method has been called just once to generate a final lambda and cache it. -// verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); -// -// assertEquals(7, calls.size()); -// assertEquals("TestPluginOne:before", calls.get(0)); -// assertEquals("TestPluginTwo:before", calls.get(1)); -// assertEquals("TestPluginThree:before", calls.get(2)); -// assertEquals("targetCall", calls.get(3)); -// assertEquals("TestPluginThree:after", calls.get(4)); -// assertEquals("TestPluginTwo:after", calls.get(5)); -// assertEquals("TestPluginOne:after", calls.get(6)); -// -// calls.clear(); -// -// result = -// target.execute( -// String.class, -// Exception.class, -// Connection.class, -// JdbcMethod.BLOB_LENGTH, -// () -> { -// calls.add("targetCall"); -// return "anotherResulTestValue"; -// }, -// testArgs); -// -// assertEquals("anotherResulTestValue", result); -// -// // No additional calls to this method occurred. It's still been called once. -// verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); -// -// assertEquals(7, calls.size()); -// assertEquals("TestPluginOne:before", calls.get(0)); -// assertEquals("TestPluginTwo:before", calls.get(1)); -// assertEquals("TestPluginThree:before", calls.get(2)); -// assertEquals("targetCall", calls.get(3)); -// assertEquals("TestPluginThree:after", calls.get(4)); -// assertEquals("TestPluginTwo:after", calls.get(5)); -// assertEquals("TestPluginOne:after", calls.get(6)); -// } -// -// @Test -// public void testForceConnectCachedJdbcCallForceConnect() throws Exception { -// -// final ArrayList calls = new ArrayList<>(); -// final Connection mockConnection = mock(Connection.class); -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls, mockConnection)); -// -// final HostSpec testHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("test-instance").build(); -// -// final Properties testProperties = new Properties(); -// -// final ConnectionPluginManager target = Mockito.spy( -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); -// -// Object result = target.forceConnect( -// "any", -// testHostSpec, -// testProperties, -// true, -// null); -// -// assertEquals(mockConnection, result); -// -// // The method has been called just once to generate a final lambda and cache it. -// verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); -// -// assertEquals(4, calls.size()); -// assertEquals("TestPluginOne:before forceConnect", calls.get(0)); -// assertEquals("TestPluginThree:before forceConnect", calls.get(1)); -// assertEquals("TestPluginThree:forced connection", calls.get(2)); -// assertEquals("TestPluginOne:after forceConnect", calls.get(3)); -// -// calls.clear(); -// -// result = target.forceConnect( -// "any", -// testHostSpec, -// testProperties, -// true, -// null); -// -// assertEquals(mockConnection, result); -// -// // No additional calls to this method occurred. It's still been called once. -// verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); -// -// assertEquals(4, calls.size()); -// assertEquals("TestPluginOne:before forceConnect", calls.get(0)); -// assertEquals("TestPluginThree:before forceConnect", calls.get(1)); -// assertEquals("TestPluginThree:forced connection", calls.get(2)); -// assertEquals("TestPluginOne:after forceConnect", calls.get(3)); -// } -// -// @Test -// public void testExecuteAgainstOldConnection() throws Exception { -// final ArrayList calls = new ArrayList<>(); -// -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(calls)); -// testPlugins.add(new TestPluginTwo(calls)); -// testPlugins.add(new TestPluginThree(calls)); -// -// final Properties testProperties = new Properties(); -// -// final Connection mockOldConnection = mock(Connection.class); -// final Connection mockCurrentConnection = mock(Connection.class); -// final Statement mockOldStatement = mock(Statement.class); -// final ResultSet mockOldResultSet = mock(ResultSet.class); -// -// when(mockPluginService.getCurrentConnection()).thenReturn(mockCurrentConnection); -// when(mockOldStatement.getConnection()).thenReturn(mockOldConnection); -// when(mockOldResultSet.getStatement()).thenReturn(mockOldStatement); -// -// final ConnectionPluginManager target = -// new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// assertThrows(SQLException.class, -// () -> target.execute(String.class, Exception.class, mockOldConnection, -// JdbcMethod.CALLABLESTATEMENT_GETCONNECTION, () -> "result", null)); -// assertThrows(SQLException.class, -// () -> target.execute(String.class, Exception.class, mockOldStatement, -// JdbcMethod.CALLABLESTATEMENT_GETMORERESULTS, () -> "result", null)); -// assertThrows(SQLException.class, -// () -> target.execute(String.class, Exception.class, mockOldResultSet, -// JdbcMethod.RESULTSET_GETSTATEMENT, () -> "result", null)); -// -// assertDoesNotThrow( -// () -> target.execute(Void.class, SQLException.class, mockOldConnection, -// JdbcMethod.CONNECTION_CLOSE, mockSqlFunction, -// null)); -// assertDoesNotThrow( -// () -> target.execute(Void.class, SQLException.class, mockOldConnection, -// JdbcMethod.CONNECTION_ABORT, mockSqlFunction, -// null)); -// assertDoesNotThrow( -// () -> target.execute(Void.class, SQLException.class, mockOldStatement, -// JdbcMethod.STATEMENT_CLOSE, mockSqlFunction, -// null)); -// assertDoesNotThrow( -// () -> target.execute(Void.class, SQLException.class, mockOldResultSet, -// JdbcMethod.RESULTSET_CLOSE, mockSqlFunction, -// null)); -// } -// -// @Test -// public void testDefaultPlugins() throws SQLException { -// final Properties testProperties = new Properties(); -// -// final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( -// mockConnectionProvider, -// null, -// mockConnectionWrapper, -// mockTelemetryFactory)); -// target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); -// -// assertEquals(4, target.plugins.size()); -// assertEquals(AuroraConnectionTrackerPlugin.class, target.plugins.get(0).getClass()); -// assertEquals(software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin.class, -// target.plugins.get(1).getClass()); -// assertEquals(HostMonitoringConnectionPlugin.class, target.plugins.get(2).getClass()); -// assertEquals(DefaultConnectionPlugin.class, target.plugins.get(3).getClass()); -// } -// -// @Test -// public void testNoWrapperPlugins() throws SQLException { -// final Properties testProperties = new Properties(); -// testProperties.setProperty(PropertyDefinition.PLUGINS.name, ""); -// -// final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( -// mockConnectionProvider, -// null, -// mockConnectionWrapper, -// mockTelemetryFactory)); -// target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); -// -// assertEquals(1, target.plugins.size()); -// } -// -// @Test -// public void testOverridingDefaultPluginsWithPluginCodes() throws SQLException { -// final Properties testProperties = new Properties(); -// testProperties.setProperty("wrapperPlugins", "logQuery"); -// -// final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( -// mockConnectionProvider, -// null, -// mockConnectionWrapper, -// mockTelemetryFactory)); -// target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); -// -// assertEquals(2, target.plugins.size()); -// assertEquals(LogQueryConnectionPlugin.class, target.plugins.get(0).getClass()); -// assertEquals(DefaultConnectionPlugin.class, target.plugins.get(1).getClass()); -// } -// -// @Test -// public void testTwoConnectionsDoNotBlockOneAnother() throws Exception { -// -// final Properties testProperties = new Properties(); -// final ArrayList testPlugins = new ArrayList<>(); -// testPlugins.add(new TestPluginOne(new ArrayList<>())); -// -// final ConnectionProvider mockConnectionProvider1 = Mockito.mock(ConnectionProvider.class); -// final ConnectionWrapper mockConnectionWrapper1 = Mockito.mock(ConnectionWrapper.class); -// final PluginService mockPluginService1 = Mockito.mock(PluginService.class); -// final TelemetryFactory mockTelemetryFactory1 = Mockito.mock(TelemetryFactory.class); -// final Object object1 = new Object(); -// when(mockPluginService1.getTelemetryFactory()).thenReturn(mockTelemetryFactory1); -// when(mockTelemetryFactory1.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory1.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// -// final ConnectionPluginManager pluginManager1 = -// new ConnectionPluginManager(mockConnectionProvider1, -// null, testProperties, testPlugins, mockConnectionWrapper1, -// mockPluginService1, mockTelemetryFactory1); -// -// final ConnectionProvider mockConnectionProvider2 = Mockito.mock(ConnectionProvider.class); -// final ConnectionWrapper mockConnectionWrapper2 = Mockito.mock(ConnectionWrapper.class); -// final PluginService mockPluginService2 = Mockito.mock(PluginService.class); -// final TelemetryFactory mockTelemetryFactory2 = Mockito.mock(TelemetryFactory.class); -// final Object object2 = new Object(); -// when(mockPluginService2.getTelemetryFactory()).thenReturn(mockTelemetryFactory2); -// when(mockTelemetryFactory2.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory2.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// -// final ConnectionPluginManager pluginManager2 = -// new ConnectionPluginManager(mockConnectionProvider2, -// null, testProperties, testPlugins, mockConnectionWrapper2, -// mockPluginService2, mockTelemetryFactory2); -// -// // Imaginary database resource is considered "locked" when latch is 0 -// final CountDownLatch waitForDbResourceLocked = new CountDownLatch(1); -// final ReentrantLock dbResourceLock = new ReentrantLock(); -// final CountDownLatch waitForReleaseDbResourceToProceed = new CountDownLatch(1); -// final AtomicBoolean dbResourceReleased = new AtomicBoolean(false); -// final AtomicBoolean acquireDbResourceLockSuccessful = new AtomicBoolean(false); -// -// CompletableFuture.allOf( -// -// // Thread 1 -// CompletableFuture.runAsync(() -> { -// -// LOGGER.info("thread-1: started"); -// -// WrapperUtils.executeWithPlugins( -// Integer.class, -// pluginManager1, -// object1, -// JdbcMethod.BLOB_POSITION, // any JdbcMethod that locks connection -// () -> { -// dbResourceLock.lock(); -// waitForDbResourceLocked.countDown(); -// LOGGER.info("thread-1: locked"); -// return 1; -// }); -// -// LOGGER.info("thread-1: waiting for thread-2"); -// try { -// waitForReleaseDbResourceToProceed.await(); -// } catch (InterruptedException e) { -// throw new RuntimeException(e); -// } -// LOGGER.info("thread-1: continue"); -// -// WrapperUtils.executeWithPlugins( -// Integer.class, -// pluginManager1, -// object1, -// JdbcMethod.BLOB_TRUNCATE, // any JdbcMethod that locks connection -// () -> { -// dbResourceLock.unlock(); -// dbResourceReleased.set(true); -// LOGGER.info("thread-1: unlocked"); -// return 1; -// }); -// LOGGER.info("thread-1: completed"); -// }), -// -// // Thread 2 -// CompletableFuture.runAsync(() -> { -// -// LOGGER.info("thread-2: started"); -// LOGGER.info("thread-2: waiting for thread-1"); -// try { -// waitForDbResourceLocked.await(); -// } catch (InterruptedException e) { -// throw new RuntimeException(e); -// } -// LOGGER.info("thread-2: continue"); -// -// WrapperUtils.executeWithPlugins( -// Integer.class, -// pluginManager2, -// object2, -// JdbcMethod.BLOB_LENGTH, // any JdbcMethod that locks connection -// () -> { -// waitForReleaseDbResourceToProceed.countDown(); -// LOGGER.info("thread-2: try to acquire a lock"); -// try { -// acquireDbResourceLockSuccessful.set(dbResourceLock.tryLock(5, TimeUnit.SECONDS)); -// } catch (InterruptedException e) { -// throw new RuntimeException(e); -// } -// return 1; -// }); -// LOGGER.info("thread-2: completed"); -// }) -// ).join(); -// -// assertTrue(dbResourceReleased.get()); -// assertTrue(acquireDbResourceLockSuccessful.get()); -// } -// -// @Test -// public void testGetHostSpecByStrategy_givenPluginWithNoSubscriptions_thenThrowsSqlException() throws SQLException { -// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); -// when(mockPlugin.getSubscribedMethods()).thenReturn(Collections.emptySet()); -// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); -// -// final List testPlugins = Collections.singletonList(mockPlugin); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// final HostRole inputHostRole = HostRole.WRITER; -// final String inputStrategy = "someStrategy"; -// -// assertThrows( -// SQLException.class, -// () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); -// } -// -// @Test -// public void testGetHostSpecByStrategy_givenPluginWithDiffSubscription_thenThrowsSqlException() throws SQLException { -// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); -// when(mockPlugin.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); -// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); -// -// final List testPlugins = Collections.singletonList(mockPlugin); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// final HostRole inputHostRole = HostRole.WRITER; -// final String inputStrategy = "someStrategy"; -// -// assertThrows( -// SQLException.class, -// () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); -// } -// -// @Test -// public void testGetHostSpecByStrategy_givenUnsupportedPlugin_thenThrowsSqlException() throws SQLException { -// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); -// when(mockPlugin.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); -// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); -// -// final List testPlugins = Collections.singletonList(mockPlugin); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// final HostRole inputHostRole = HostRole.WRITER; -// final String inputStrategy = "someStrategy"; -// -// assertThrows( -// SQLException.class, -// () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); -// } -// -// @Test -// public void testGetHostSpecByStrategy_givenSupportedSubscribedPlugin_thenThrowsSqlException() throws SQLException { -// final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); -// -// when(mockPlugin.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); -// -// final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("expected-instance").build(); -// when(mockPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); -// -// final List testPlugins = Collections.singletonList(mockPlugin); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// final HostRole inputHostRole = HostRole.WRITER; -// final String inputStrategy = "someStrategy"; -// final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); -// -// verify(mockPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); -// assertEquals(expectedHostSpec, actualHostSpec); -// } -// -// @Test -// public void testGetHostSpecByStrategy_givenMultiplePlugins() throws SQLException { -// final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); -// final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); -// final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); -// final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); -// final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); -// -// final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, -// unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); -// -// when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); -// when(unsubscribedPlugin1.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); -// when(unsupportedSubscribedPlugin0.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); -// when(unsupportedSubscribedPlugin1.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); -// when(supportedSubscribedPlugin.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); -// -// when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); -// when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); -// when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any())) -// .thenThrow(new UnsupportedOperationException()); -// when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any())) -// .thenThrow(new UnsupportedOperationException()); -// -// final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("expected-instance").build(); -// when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// final HostRole inputHostRole = HostRole.WRITER; -// final String inputStrategy = "someStrategy"; -// final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); -// -// verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); -// assertEquals(expectedHostSpec, actualHostSpec); -// } -// -// @Test -// public void testGetHostSpecByStrategy_givenInputHostsAndMultiplePlugins() throws SQLException { -// final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); -// final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); -// final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); -// final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); -// final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); -// -// final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, -// unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); -// -// when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); -// when(unsubscribedPlugin1.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); -// when(unsupportedSubscribedPlugin0.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); -// when(unsupportedSubscribedPlugin1.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); -// when(supportedSubscribedPlugin.getSubscribedMethods()) -// .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); -// -// when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); -// when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); -// when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())) -// .thenThrow(new UnsupportedOperationException()); -// when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())) -// .thenThrow(new UnsupportedOperationException()); -// -// final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("expected-instance").build(); -// when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any(), any())).thenReturn(expectedHostSpec); -// -// final Properties testProperties = new Properties(); -// final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, -// null, testProperties, testPlugins, mockConnectionWrapper, -// mockPluginService, mockTelemetryFactory); -// -// final List inputHosts = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("expected-instance").build()); -// final HostRole inputHostRole = HostRole.WRITER; -// final String inputStrategy = "someStrategy"; -// final HostSpec actualHostSpec = -// connectionPluginManager.getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); -// -// verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); -// assertEquals(expectedHostSpec, actualHostSpec); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; +import java.util.logging.Logger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.mock.TestPluginOne; +import software.amazon.jdbc.mock.TestPluginThree; +import software.amazon.jdbc.mock.TestPluginThrowException; +import software.amazon.jdbc.mock.TestPluginTwo; +import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; +import software.amazon.jdbc.plugin.DefaultConnectionPlugin; +import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; +import software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin; +import software.amazon.jdbc.profile.ConfigurationProfile; +import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.wrapper.ConnectionWrapper; + +public class ConnectionPluginManagerTests { + + private static final Logger LOGGER = Logger.getLogger(ConnectionPluginManagerTests.class.getName()); + + @Mock JdbcCallable mockSqlFunction; + @Mock ConnectionProvider mockConnectionProvider; + @Mock ConnectionWrapper mockConnectionWrapper; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock FullServicesContainer mockServicesContainer; + @Mock PluginService mockPluginService; + @Mock PluginManagerService mockPluginManagerService; + @Mock TargetDriverDialect mockTargetDriverDialect; + + ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); + + private AutoCloseable closeable; + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @BeforeEach + void init() { + closeable = MockitoAnnotations.openMocks(this); + when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); + when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); + when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); + } + + @Test + public void testExecuteJdbcCallA() throws Exception { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls)); + + final Properties testProperties = new Properties(); + + final Object[] testArgs = new Object[] {10, "arg2", 3.33}; + + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Object result = + target.execute( + String.class, + Exception.class, + Connection.class, + JdbcMethod.BLOB_LENGTH, + () -> { + calls.add("targetCall"); + return "resulTestValue"; + }, + testArgs); + + assertEquals("resulTestValue", result); + + assertEquals(7, calls.size()); + assertEquals("TestPluginOne:before", calls.get(0)); + assertEquals("TestPluginTwo:before", calls.get(1)); + assertEquals("TestPluginThree:before", calls.get(2)); + assertEquals("targetCall", calls.get(3)); + assertEquals("TestPluginThree:after", calls.get(4)); + assertEquals("TestPluginTwo:after", calls.get(5)); + assertEquals("TestPluginOne:after", calls.get(6)); + } + + @Test + public void testExecuteJdbcCallB() throws Exception { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls)); + + final Properties testProperties = new Properties(); + + final Object[] testArgs = new Object[] {10, "arg2", 3.33}; + + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Object result = + target.execute( + String.class, + Exception.class, + Connection.class, + JdbcMethod.BLOB_POSITION, + () -> { + calls.add("targetCall"); + return "resulTestValue"; + }, + testArgs); + + assertEquals("resulTestValue", result); + + assertEquals(5, calls.size()); + assertEquals("TestPluginOne:before", calls.get(0)); + assertEquals("TestPluginTwo:before", calls.get(1)); + assertEquals("targetCall", calls.get(2)); + assertEquals("TestPluginTwo:after", calls.get(3)); + assertEquals("TestPluginOne:after", calls.get(4)); + } + + @Test + public void testExecuteJdbcCallC() throws Exception { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls)); + + final Properties testProperties = new Properties(); + + final Object[] testArgs = new Object[] {10, "arg2", 3.33}; + + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Object result = + target.execute( + String.class, + Exception.class, + Connection.class, + JdbcMethod.BLOB_GETBYTES, + () -> { + calls.add("targetCall"); + return "resulTestValue"; + }, + testArgs); + + assertEquals("resulTestValue", result); + + assertEquals(3, calls.size()); + assertEquals("TestPluginOne:before", calls.get(0)); + assertEquals("targetCall", calls.get(1)); + assertEquals("TestPluginOne:after", calls.get(2)); + } + + @Test + public void testConnect() throws Exception { + + final Connection expectedConnection = mock(Connection.class); + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls, expectedConnection)); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Connection conn = target.connect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + true, null); + + assertEquals(expectedConnection, conn); + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before connect", calls.get(0)); + assertEquals("TestPluginThree:before connect", calls.get(1)); + assertEquals("TestPluginThree:connection", calls.get(2)); + assertEquals("TestPluginOne:after connect", calls.get(3)); + } + + @Test + public void testConnectWithSkipPlugin() throws Exception { + + final Connection expectedConnection = mock(Connection.class); + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + final ConnectionPlugin pluginOne = new TestPluginOne(calls); + testPlugins.add(pluginOne); + final ConnectionPlugin pluginTwo = new TestPluginTwo(calls); + testPlugins.add(pluginTwo); + final ConnectionPlugin pluginThree = new TestPluginThree(calls, expectedConnection); + testPlugins.add(pluginThree); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Connection conn = target.connect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + true, pluginOne); + + assertEquals(expectedConnection, conn); + assertEquals(2, calls.size()); + assertEquals("TestPluginThree:before connect", calls.get(0)); + assertEquals("TestPluginThree:connection", calls.get(1)); + } + + @Test + public void testForceConnect() throws Exception { + + final Connection expectedConnection = mock(Connection.class); + final ArrayList calls = new ArrayList<>(); + final ArrayList testPlugins = new ArrayList<>(); + + // TestPluginOne is not an AuthenticationConnectionPlugin. + testPlugins.add(new TestPluginOne(calls)); + + // TestPluginTwo is an AuthenticationConnectionPlugin, but it's not subscribed to "forceConnect" method. + testPlugins.add(new TestPluginTwo(calls)); + + // TestPluginThree is an AuthenticationConnectionPlugin, and it's subscribed to "forceConnect" method. + testPlugins.add(new TestPluginThree(calls, expectedConnection)); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Connection conn = target.forceConnect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + true, + null); + + // Expecting only TestPluginThree to participate in forceConnect(). + assertEquals(expectedConnection, conn); + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before forceConnect", calls.get(0)); + assertEquals("TestPluginThree:before forceConnect", calls.get(1)); + assertEquals("TestPluginThree:forced connection", calls.get(2)); + assertEquals("TestPluginOne:after forceConnect", calls.get(3)); + } + + @Test + public void testConnectWithSQLExceptionBefore() { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThrowException(calls, SQLException.class, true)); + testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + assertThrows( + SQLException.class, + () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + testProperties, true, null)); + + assertEquals(2, calls.size()); + assertEquals("TestPluginOne:before connect", calls.get(0)); + assertEquals("TestPluginThrowException:before", calls.get(1)); + } + + @Test + public void testConnectWithSQLExceptionAfter() { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThrowException(calls, SQLException.class, false)); + testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + assertThrows( + SQLException.class, + () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + testProperties, true, null)); + + assertEquals(5, calls.size()); + assertEquals("TestPluginOne:before connect", calls.get(0)); + assertEquals("TestPluginThrowException:before", calls.get(1)); + assertEquals("TestPluginThree:before connect", calls.get(2)); + assertEquals("TestPluginThree:connection", calls.get(3)); + assertEquals("TestPluginThrowException:after", calls.get(4)); + } + + @Test + public void testConnectWithUnexpectedExceptionBefore() { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, true)); + testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Exception ex = + assertThrows( + IllegalArgumentException.class, + () -> target.connect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + testProperties, true, null)); + + assertEquals(2, calls.size()); + assertEquals("TestPluginOne:before connect", calls.get(0)); + assertEquals("TestPluginThrowException:before", calls.get(1)); + } + + @Test + public void testConnectWithUnexpectedExceptionAfter() { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThrowException(calls, IllegalArgumentException.class, false)); + testPlugins.add(new TestPluginThree(calls, mock(Connection.class))); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Exception ex = + assertThrows( + IllegalArgumentException.class, + () -> target.connect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + testProperties, true, null)); + + assertEquals(5, calls.size()); + assertEquals("TestPluginOne:before connect", calls.get(0)); + assertEquals("TestPluginThrowException:before", calls.get(1)); + assertEquals("TestPluginThree:before connect", calls.get(2)); + assertEquals("TestPluginThree:connection", calls.get(3)); + assertEquals("TestPluginThrowException:after", calls.get(4)); + } + + @Test + public void testExecuteCachedJdbcCallA() throws Exception { + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls)); + + final Properties testProperties = new Properties(); + + final Object[] testArgs = new Object[] {10, "arg2", 3.33}; + + final ConnectionPluginManager target = Mockito.spy( + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); + + Object result = + target.execute( + String.class, + Exception.class, + Connection.class, + JdbcMethod.BLOB_LENGTH, + () -> { + calls.add("targetCall"); + return "resulTestValue"; + }, + testArgs); + + assertEquals("resulTestValue", result); + + // The method has been called just once to generate a final lambda and cache it. + verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); + + assertEquals(7, calls.size()); + assertEquals("TestPluginOne:before", calls.get(0)); + assertEquals("TestPluginTwo:before", calls.get(1)); + assertEquals("TestPluginThree:before", calls.get(2)); + assertEquals("targetCall", calls.get(3)); + assertEquals("TestPluginThree:after", calls.get(4)); + assertEquals("TestPluginTwo:after", calls.get(5)); + assertEquals("TestPluginOne:after", calls.get(6)); + + calls.clear(); + + result = + target.execute( + String.class, + Exception.class, + Connection.class, + JdbcMethod.BLOB_LENGTH, + () -> { + calls.add("targetCall"); + return "anotherResulTestValue"; + }, + testArgs); + + assertEquals("anotherResulTestValue", result); + + // No additional calls to this method occurred. It's still been called once. + verify(target, times(1)).makePluginChainFunc(eq(JdbcMethod.BLOB_LENGTH.methodName)); + + assertEquals(7, calls.size()); + assertEquals("TestPluginOne:before", calls.get(0)); + assertEquals("TestPluginTwo:before", calls.get(1)); + assertEquals("TestPluginThree:before", calls.get(2)); + assertEquals("targetCall", calls.get(3)); + assertEquals("TestPluginThree:after", calls.get(4)); + assertEquals("TestPluginTwo:after", calls.get(5)); + assertEquals("TestPluginOne:after", calls.get(6)); + } + + @Test + public void testForceConnectCachedJdbcCallForceConnect() throws Exception { + + final ArrayList calls = new ArrayList<>(); + final Connection mockConnection = mock(Connection.class); + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls, mockConnection)); + + final HostSpec testHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("test-instance").build(); + + final Properties testProperties = new Properties(); + + final ConnectionPluginManager target = Mockito.spy( + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); + + Object result = target.forceConnect( + "any", + testHostSpec, + testProperties, + true, + null); + + assertEquals(mockConnection, result); + + // The method has been called just once to generate a final lambda and cache it. + verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); + + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before forceConnect", calls.get(0)); + assertEquals("TestPluginThree:before forceConnect", calls.get(1)); + assertEquals("TestPluginThree:forced connection", calls.get(2)); + assertEquals("TestPluginOne:after forceConnect", calls.get(3)); + + calls.clear(); + + result = target.forceConnect( + "any", + testHostSpec, + testProperties, + true, + null); + + assertEquals(mockConnection, result); + + // No additional calls to this method occurred. It's still been called once. + verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); + + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before forceConnect", calls.get(0)); + assertEquals("TestPluginThree:before forceConnect", calls.get(1)); + assertEquals("TestPluginThree:forced connection", calls.get(2)); + assertEquals("TestPluginOne:after forceConnect", calls.get(3)); + } + + @Test + public void testExecuteAgainstOldConnection() throws Exception { + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls)); + + final Properties testProperties = new Properties(); + + final Connection mockOldConnection = mock(Connection.class); + final Connection mockCurrentConnection = mock(Connection.class); + final Statement mockOldStatement = mock(Statement.class); + final ResultSet mockOldResultSet = mock(ResultSet.class); + + when(mockPluginService.getCurrentConnection()).thenReturn(mockCurrentConnection); + when(mockOldStatement.getConnection()).thenReturn(mockOldConnection); + when(mockOldResultSet.getStatement()).thenReturn(mockOldStatement); + + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + assertThrows(SQLException.class, + () -> target.execute(String.class, Exception.class, mockOldConnection, + JdbcMethod.CALLABLESTATEMENT_GETCONNECTION, () -> "result", null)); + assertThrows(SQLException.class, + () -> target.execute(String.class, Exception.class, mockOldStatement, + JdbcMethod.CALLABLESTATEMENT_GETMORERESULTS, () -> "result", null)); + assertThrows(SQLException.class, + () -> target.execute(String.class, Exception.class, mockOldResultSet, + JdbcMethod.RESULTSET_GETSTATEMENT, () -> "result", null)); + + assertDoesNotThrow( + () -> target.execute(Void.class, SQLException.class, mockOldConnection, + JdbcMethod.CONNECTION_CLOSE, mockSqlFunction, + null)); + assertDoesNotThrow( + () -> target.execute(Void.class, SQLException.class, mockOldConnection, + JdbcMethod.CONNECTION_ABORT, mockSqlFunction, + null)); + assertDoesNotThrow( + () -> target.execute(Void.class, SQLException.class, mockOldStatement, + JdbcMethod.STATEMENT_CLOSE, mockSqlFunction, + null)); + assertDoesNotThrow( + () -> target.execute(Void.class, SQLException.class, mockOldResultSet, + JdbcMethod.RESULTSET_CLOSE, mockSqlFunction, + null)); + } + + @Test + public void testDefaultPlugins() throws SQLException { + final Properties testProperties = new Properties(); + + final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( + mockConnectionProvider, + null, + mockConnectionWrapper, + mockTelemetryFactory)); + target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); + + assertEquals(4, target.plugins.size()); + assertEquals(AuroraConnectionTrackerPlugin.class, target.plugins.get(0).getClass()); + assertEquals(software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin.class, + target.plugins.get(1).getClass()); + assertEquals(HostMonitoringConnectionPlugin.class, target.plugins.get(2).getClass()); + assertEquals(DefaultConnectionPlugin.class, target.plugins.get(3).getClass()); + } + + @Test + public void testNoWrapperPlugins() throws SQLException { + final Properties testProperties = new Properties(); + testProperties.setProperty(PropertyDefinition.PLUGINS.name, ""); + + final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( + mockConnectionProvider, + null, + mockConnectionWrapper, + mockTelemetryFactory)); + target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); + + assertEquals(1, target.plugins.size()); + } + + @Test + public void testOverridingDefaultPluginsWithPluginCodes() throws SQLException { + final Properties testProperties = new Properties(); + testProperties.setProperty("wrapperPlugins", "logQuery"); + + final ConnectionPluginManager target = Mockito.spy(new ConnectionPluginManager( + mockConnectionProvider, + null, + mockConnectionWrapper, + mockTelemetryFactory)); + target.init(mockServicesContainer, testProperties, mockPluginManagerService, configurationProfile); + + assertEquals(2, target.plugins.size()); + assertEquals(LogQueryConnectionPlugin.class, target.plugins.get(0).getClass()); + assertEquals(DefaultConnectionPlugin.class, target.plugins.get(1).getClass()); + } + + @Test + public void testTwoConnectionsDoNotBlockOneAnother() throws Exception { + + final Properties testProperties = new Properties(); + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(new ArrayList<>())); + + final ConnectionProvider mockConnectionProvider1 = Mockito.mock(ConnectionProvider.class); + final ConnectionWrapper mockConnectionWrapper1 = Mockito.mock(ConnectionWrapper.class); + final PluginService mockPluginService1 = Mockito.mock(PluginService.class); + final TelemetryFactory mockTelemetryFactory1 = Mockito.mock(TelemetryFactory.class); + final Object object1 = new Object(); + when(mockPluginService1.getTelemetryFactory()).thenReturn(mockTelemetryFactory1); + when(mockTelemetryFactory1.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory1.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + + final ConnectionPluginManager pluginManager1 = + new ConnectionPluginManager(mockConnectionProvider1, + null, testProperties, testPlugins, mockConnectionWrapper1, + mockPluginService1, mockTelemetryFactory1); + + final ConnectionProvider mockConnectionProvider2 = Mockito.mock(ConnectionProvider.class); + final ConnectionWrapper mockConnectionWrapper2 = Mockito.mock(ConnectionWrapper.class); + final PluginService mockPluginService2 = Mockito.mock(PluginService.class); + final TelemetryFactory mockTelemetryFactory2 = Mockito.mock(TelemetryFactory.class); + final Object object2 = new Object(); + when(mockPluginService2.getTelemetryFactory()).thenReturn(mockTelemetryFactory2); + when(mockTelemetryFactory2.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory2.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + + final ConnectionPluginManager pluginManager2 = + new ConnectionPluginManager(mockConnectionProvider2, + null, testProperties, testPlugins, mockConnectionWrapper2, + mockPluginService2, mockTelemetryFactory2); + + // Imaginary database resource is considered "locked" when latch is 0 + final CountDownLatch waitForDbResourceLocked = new CountDownLatch(1); + final ReentrantLock dbResourceLock = new ReentrantLock(); + final CountDownLatch waitForReleaseDbResourceToProceed = new CountDownLatch(1); + final AtomicBoolean dbResourceReleased = new AtomicBoolean(false); + final AtomicBoolean acquireDbResourceLockSuccessful = new AtomicBoolean(false); + + CompletableFuture.allOf( + + // Thread 1 + CompletableFuture.runAsync(() -> { + + LOGGER.info("thread-1: started"); + + WrapperUtils.executeWithPlugins( + Integer.class, + pluginManager1, + object1, + JdbcMethod.BLOB_POSITION, // any JdbcMethod that locks connection + () -> { + dbResourceLock.lock(); + waitForDbResourceLocked.countDown(); + LOGGER.info("thread-1: locked"); + return 1; + }); + + LOGGER.info("thread-1: waiting for thread-2"); + try { + waitForReleaseDbResourceToProceed.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + LOGGER.info("thread-1: continue"); + + WrapperUtils.executeWithPlugins( + Integer.class, + pluginManager1, + object1, + JdbcMethod.BLOB_TRUNCATE, // any JdbcMethod that locks connection + () -> { + dbResourceLock.unlock(); + dbResourceReleased.set(true); + LOGGER.info("thread-1: unlocked"); + return 1; + }); + LOGGER.info("thread-1: completed"); + }), + + // Thread 2 + CompletableFuture.runAsync(() -> { + + LOGGER.info("thread-2: started"); + LOGGER.info("thread-2: waiting for thread-1"); + try { + waitForDbResourceLocked.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + LOGGER.info("thread-2: continue"); + + WrapperUtils.executeWithPlugins( + Integer.class, + pluginManager2, + object2, + JdbcMethod.BLOB_LENGTH, // any JdbcMethod that locks connection + () -> { + waitForReleaseDbResourceToProceed.countDown(); + LOGGER.info("thread-2: try to acquire a lock"); + try { + acquireDbResourceLockSuccessful.set(dbResourceLock.tryLock(5, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return 1; + }); + LOGGER.info("thread-2: completed"); + }) + ).join(); + + assertTrue(dbResourceReleased.get()); + assertTrue(acquireDbResourceLockSuccessful.get()); + } + + @Test + public void testGetHostSpecByStrategy_givenPluginWithNoSubscriptions_thenThrowsSqlException() throws SQLException { + final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); + when(mockPlugin.getSubscribedMethods()).thenReturn(Collections.emptySet()); + when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); + + final List testPlugins = Collections.singletonList(mockPlugin); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + final HostRole inputHostRole = HostRole.WRITER; + final String inputStrategy = "someStrategy"; + + assertThrows( + SQLException.class, + () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); + } + + @Test + public void testGetHostSpecByStrategy_givenPluginWithDiffSubscription_thenThrowsSqlException() throws SQLException { + final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); + when(mockPlugin.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); + when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); + + final List testPlugins = Collections.singletonList(mockPlugin); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + final HostRole inputHostRole = HostRole.WRITER; + final String inputStrategy = "someStrategy"; + + assertThrows( + SQLException.class, + () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); + } + + @Test + public void testGetHostSpecByStrategy_givenUnsupportedPlugin_thenThrowsSqlException() throws SQLException { + final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); + when(mockPlugin.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); + when(mockPlugin.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); + + final List testPlugins = Collections.singletonList(mockPlugin); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + final HostRole inputHostRole = HostRole.WRITER; + final String inputStrategy = "someStrategy"; + + assertThrows( + SQLException.class, + () -> connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy)); + } + + @Test + public void testGetHostSpecByStrategy_givenSupportedSubscribedPlugin_thenThrowsSqlException() throws SQLException { + final ConnectionPlugin mockPlugin = mock(ConnectionPlugin.class); + + when(mockPlugin.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); + + final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("expected-instance").build(); + when(mockPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); + + final List testPlugins = Collections.singletonList(mockPlugin); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + final HostRole inputHostRole = HostRole.WRITER; + final String inputStrategy = "someStrategy"; + final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); + + verify(mockPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); + assertEquals(expectedHostSpec, actualHostSpec); + } + + @Test + public void testGetHostSpecByStrategy_givenMultiplePlugins() throws SQLException { + final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); + final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); + final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); + final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); + final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); + + final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, + unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); + + when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); + when(unsubscribedPlugin1.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); + when(unsupportedSubscribedPlugin0.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); + when(unsupportedSubscribedPlugin1.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); + when(supportedSubscribedPlugin.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); + + when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); + when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any())).thenThrow(new UnsupportedOperationException()); + when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any())) + .thenThrow(new UnsupportedOperationException()); + when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any())) + .thenThrow(new UnsupportedOperationException()); + + final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("expected-instance").build(); + when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any())).thenReturn(expectedHostSpec); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + final HostRole inputHostRole = HostRole.WRITER; + final String inputStrategy = "someStrategy"; + final HostSpec actualHostSpec = connectionPluginManager.getHostSpecByStrategy(inputHostRole, inputStrategy); + + verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHostRole, inputStrategy); + assertEquals(expectedHostSpec, actualHostSpec); + } + + @Test + public void testGetHostSpecByStrategy_givenInputHostsAndMultiplePlugins() throws SQLException { + final ConnectionPlugin unsubscribedPlugin0 = mock(ConnectionPlugin.class); + final ConnectionPlugin unsupportedSubscribedPlugin0 = mock(ConnectionPlugin.class); + final ConnectionPlugin unsubscribedPlugin1 = mock(ConnectionPlugin.class); + final ConnectionPlugin unsupportedSubscribedPlugin1 = mock(ConnectionPlugin.class); + final ConnectionPlugin supportedSubscribedPlugin = mock(ConnectionPlugin.class); + + final List testPlugins = Arrays.asList(unsubscribedPlugin0, unsupportedSubscribedPlugin0, + unsubscribedPlugin1, unsupportedSubscribedPlugin1, supportedSubscribedPlugin); + + when(unsubscribedPlugin0.getSubscribedMethods()).thenReturn(Collections.emptySet()); + when(unsubscribedPlugin1.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.CONNECT.methodName))); + when(unsupportedSubscribedPlugin0.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.ALL.methodName))); + when(unsupportedSubscribedPlugin1.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); + when(supportedSubscribedPlugin.getSubscribedMethods()) + .thenReturn(new HashSet<>(Collections.singletonList(JdbcMethod.GETHOSTSPECBYSTRATEGY.methodName))); + + when(unsubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); + when(unsubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())).thenThrow(new UnsupportedOperationException()); + when(unsupportedSubscribedPlugin0.getHostSpecByStrategy(any(), any(), any())) + .thenThrow(new UnsupportedOperationException()); + when(unsupportedSubscribedPlugin1.getHostSpecByStrategy(any(), any(), any())) + .thenThrow(new UnsupportedOperationException()); + + final HostSpec expectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("expected-instance").build(); + when(supportedSubscribedPlugin.getHostSpecByStrategy(any(), any(), any())).thenReturn(expectedHostSpec); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager connectionPluginManager = new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, + mockPluginService, mockTelemetryFactory); + + final List inputHosts = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("expected-instance").build()); + final HostRole inputHostRole = HostRole.WRITER; + final String inputStrategy = "someStrategy"; + final HostSpec actualHostSpec = + connectionPluginManager.getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); + + verify(supportedSubscribedPlugin, times(1)).getHostSpecByStrategy(inputHosts, inputHostRole, inputStrategy); + assertEquals(expectedHostSpec, actualHostSpec); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java index 735cebd67..3b47f12bb 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java @@ -1,289 +1,289 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.ResultSetMetaData; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.ArrayList; -// import java.util.Properties; -// import java.util.stream.Stream; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Disabled; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.Arguments; -// import org.junit.jupiter.params.provider.MethodSource; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.dialect.AuroraMysqlDialect; -// import software.amazon.jdbc.dialect.AuroraPgDialect; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.dialect.DialectManager; -// import software.amazon.jdbc.dialect.MariaDbDialect; -// import software.amazon.jdbc.dialect.MysqlDialect; -// import software.amazon.jdbc.dialect.PgDialect; -// import software.amazon.jdbc.dialect.RdsMultiAzDbClusterMysqlDialect; -// import software.amazon.jdbc.dialect.RdsMultiAzDbClusterPgDialect; -// import software.amazon.jdbc.dialect.RdsMysqlDialect; -// import software.amazon.jdbc.dialect.RdsPgDialect; -// import software.amazon.jdbc.exceptions.ExceptionManager; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.storage.StorageService; -// -// public class DialectDetectionTests { -// private static final String LOCALHOST = "localhost"; -// private static final String RDS_DATABASE = "database-1.xyz.us-east-2.rds.amazonaws.com"; -// private static final String RDS_AURORA_DATABASE = "database-2.cluster-xyz.us-east-2.rds.amazonaws.com"; -// private static final String MYSQL_PROTOCOL = "jdbc:mysql://"; -// private static final String PG_PROTOCOL = "jdbc:postgresql://"; -// private static final String MARIA_PROTOCOL = "jdbc:mariadb://"; -// private final Properties props = new Properties(); -// private AutoCloseable closeable; -// @Mock private FullServicesContainer mockServicesContainer; -// @Mock private HostListProviderService mockHostListProviderService; -// @Mock private StorageService mockStorageService; -// @Mock private Connection mockConnection; -// @Mock private Statement mockStatement; -// @Mock private ResultSet mockSuccessResultSet; -// @Mock private ResultSet mockFailResultSet; -// @Mock private HostSpec mockHost; -// @Mock private ConnectionPluginManager mockPluginManager; -// @Mock private TargetDriverDialect mockTargetDriverDialect; -// @Mock private ResultSetMetaData mockResultSetMetaData; -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// when(this.mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); -// when(this.mockServicesContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); -// when(this.mockServicesContainer.getStorageService()).thenReturn(mockStorageService); -// when(this.mockConnection.createStatement()).thenReturn(this.mockStatement); -// when(this.mockHost.getUrl()).thenReturn("url"); -// when(this.mockFailResultSet.next()).thenReturn(false); -// mockPluginManager.plugins = new ArrayList<>(); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// DialectManager.resetEndpointCache(); -// } -// -// PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { -// PluginServiceImpl pluginService = spy( -// new PluginServiceImpl( -// mockServicesContainer, -// new ExceptionManager(), -// props, -// protocol + host, -// protocol, -// null, -// mockTargetDriverDialect, -// null, -// null)); -// -// when(this.mockServicesContainer.getHostListProviderService()).thenReturn(pluginService); -// return pluginService; -// } -// -// @ParameterizedTest -// @MethodSource("getInitialDialectArguments") -// public void testInitialDialectDetection(String protocol, String host, Object expectedDialect) throws SQLException { -// final DialectManager dialectManager = new DialectManager(this.getPluginService(host, protocol)); -// final Dialect dialect = dialectManager.getDialect(protocol, host, new Properties()); -// assertEquals(expectedDialect, dialect.getClass()); -// } -// -// static Stream getInitialDialectArguments() { -// return Stream.of( -// Arguments.of(MYSQL_PROTOCOL, LOCALHOST, MysqlDialect.class), -// Arguments.of(MYSQL_PROTOCOL, RDS_DATABASE, RdsMysqlDialect.class), -// Arguments.of(MYSQL_PROTOCOL, RDS_AURORA_DATABASE, AuroraMysqlDialect.class), -// Arguments.of(PG_PROTOCOL, LOCALHOST, PgDialect.class), -// Arguments.of(PG_PROTOCOL, RDS_DATABASE, RdsPgDialect.class), -// Arguments.of(PG_PROTOCOL, RDS_AURORA_DATABASE, AuroraPgDialect.class), -// Arguments.of(MARIA_PROTOCOL, LOCALHOST, MariaDbDialect.class), -// Arguments.of(MARIA_PROTOCOL, RDS_DATABASE, MariaDbDialect.class), -// Arguments.of(MARIA_PROTOCOL, RDS_AURORA_DATABASE, MariaDbDialect.class) -// ); -// } -// -// @Test -// void testUpdateDialectMysqlUnchanged() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(MysqlDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectMysqlToRds() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); -// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); -// when(mockSuccessResultSet.getString(2)).thenReturn( -// "Source distribution", "Source distribution", ""); -// when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); -// when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); -// when(mockFailResultSet.next()).thenReturn(false); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); -// } -// -// @Test -// @Disabled -// // TODO: fix me: need to split this test into two separate tests: -// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect -// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() -// void testUpdateDialectMysqlToTaz() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); -// when(mockSuccessResultSet.next()).thenReturn(true); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectMysqlToAurora() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); -// when(mockSuccessResultSet.next()).thenReturn(true, false); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); -// when(mockServicesContainer.getPluginService()).thenReturn(target); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectPgUnchanged() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(PgDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectPgToRds() throws SQLException { -// when(mockStatement.executeQuery(any())) -// .thenReturn(mockSuccessResultSet, mockFailResultSet, mockFailResultSet, mockSuccessResultSet); -// when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); -// when(mockSuccessResultSet.getBoolean("rds_tools")).thenReturn(true); -// when(mockSuccessResultSet.getBoolean("aurora_stat_utils")).thenReturn(false); -// when(mockSuccessResultSet.next()).thenReturn(true); -// when(mockFailResultSet.next()).thenReturn(false); -// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(RdsPgDialect.class, target.dialect.getClass()); -// } -// -// @Test -// @Disabled -// // TODO: fix me: need to split this test into two separate tests: -// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect -// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() -// void testUpdateDialectPgToTaz() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); -// when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); -// when(mockSuccessResultSet.next()).thenReturn(true); -// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(RdsMultiAzDbClusterPgDialect.class, target.dialect.getClass()); -// } -// -// @Test -// @Disabled -// // TODO: fix me: need to split this test into two separate tests: -// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect -// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() -// void testUpdateDialectPgToAurora() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); -// when(mockSuccessResultSet.next()).thenReturn(true); -// when(mockSuccessResultSet.getBoolean(any())).thenReturn(true); -// final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(AuroraPgDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectMariaUnchanged() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(MariaDbDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectMariaToMysqlRds() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); -// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); -// when(mockSuccessResultSet.getString(2)).thenReturn( -// "Source distribution", "Source distribution", ""); -// when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); -// when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); -// when(mockFailResultSet.next()).thenReturn(false); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); -// } -// -// @Test -// @Disabled -// // TODO: fix me: need to split this test into two separate tests: -// // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect -// // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() -// void testUpdateDialectMariaToMysqlTaz() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(RdsMultiAzDbClusterMysqlDialect.class, target.dialect.getClass()); -// } -// -// @Test -// void testUpdateDialectMariaToMysqlAurora() throws SQLException { -// when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); -// when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); -// when(mockSuccessResultSet.next()).thenReturn(true, false); -// final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); -// when(mockServicesContainer.getPluginService()).thenReturn(target); -// target.setInitialConnectionHostSpec(mockHost); -// target.updateDialect(mockConnection); -// assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Properties; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.dialect.AuroraMysqlDialect; +import software.amazon.jdbc.dialect.AuroraPgDialect; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.dialect.DialectManager; +import software.amazon.jdbc.dialect.MariaDbDialect; +import software.amazon.jdbc.dialect.MysqlDialect; +import software.amazon.jdbc.dialect.PgDialect; +import software.amazon.jdbc.dialect.RdsMultiAzDbClusterMysqlDialect; +import software.amazon.jdbc.dialect.RdsMultiAzDbClusterPgDialect; +import software.amazon.jdbc.dialect.RdsMysqlDialect; +import software.amazon.jdbc.dialect.RdsPgDialect; +import software.amazon.jdbc.exceptions.ExceptionManager; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.storage.StorageService; + +public class DialectDetectionTests { + private static final String LOCALHOST = "localhost"; + private static final String RDS_DATABASE = "database-1.xyz.us-east-2.rds.amazonaws.com"; + private static final String RDS_AURORA_DATABASE = "database-2.cluster-xyz.us-east-2.rds.amazonaws.com"; + private static final String MYSQL_PROTOCOL = "jdbc:mysql://"; + private static final String PG_PROTOCOL = "jdbc:postgresql://"; + private static final String MARIA_PROTOCOL = "jdbc:mariadb://"; + private final Properties props = new Properties(); + private AutoCloseable closeable; + @Mock private FullServicesContainer mockServicesContainer; + @Mock private HostListProviderService mockHostListProviderService; + @Mock private StorageService mockStorageService; + @Mock private Connection mockConnection; + @Mock private Statement mockStatement; + @Mock private ResultSet mockSuccessResultSet; + @Mock private ResultSet mockFailResultSet; + @Mock private HostSpec mockHost; + @Mock private ConnectionPluginManager mockPluginManager; + @Mock private TargetDriverDialect mockTargetDriverDialect; + @Mock private ResultSetMetaData mockResultSetMetaData; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + when(this.mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); + when(this.mockServicesContainer.getConnectionPluginManager()).thenReturn(mockPluginManager); + when(this.mockServicesContainer.getStorageService()).thenReturn(mockStorageService); + when(this.mockConnection.createStatement()).thenReturn(this.mockStatement); + when(this.mockHost.getUrl()).thenReturn("url"); + when(this.mockFailResultSet.next()).thenReturn(false); + mockPluginManager.plugins = new ArrayList<>(); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + DialectManager.resetEndpointCache(); + } + + PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { + PluginServiceImpl pluginService = spy( + new PluginServiceImpl( + mockServicesContainer, + new ExceptionManager(), + props, + protocol + host, + protocol, + null, + mockTargetDriverDialect, + null, + null)); + + when(this.mockServicesContainer.getHostListProviderService()).thenReturn(pluginService); + return pluginService; + } + + @ParameterizedTest + @MethodSource("getInitialDialectArguments") + public void testInitialDialectDetection(String protocol, String host, Object expectedDialect) throws SQLException { + final DialectManager dialectManager = new DialectManager(this.getPluginService(host, protocol)); + final Dialect dialect = dialectManager.getDialect(protocol, host, new Properties()); + assertEquals(expectedDialect, dialect.getClass()); + } + + static Stream getInitialDialectArguments() { + return Stream.of( + Arguments.of(MYSQL_PROTOCOL, LOCALHOST, MysqlDialect.class), + Arguments.of(MYSQL_PROTOCOL, RDS_DATABASE, RdsMysqlDialect.class), + Arguments.of(MYSQL_PROTOCOL, RDS_AURORA_DATABASE, AuroraMysqlDialect.class), + Arguments.of(PG_PROTOCOL, LOCALHOST, PgDialect.class), + Arguments.of(PG_PROTOCOL, RDS_DATABASE, RdsPgDialect.class), + Arguments.of(PG_PROTOCOL, RDS_AURORA_DATABASE, AuroraPgDialect.class), + Arguments.of(MARIA_PROTOCOL, LOCALHOST, MariaDbDialect.class), + Arguments.of(MARIA_PROTOCOL, RDS_DATABASE, MariaDbDialect.class), + Arguments.of(MARIA_PROTOCOL, RDS_AURORA_DATABASE, MariaDbDialect.class) + ); + } + + @Test + void testUpdateDialectMysqlUnchanged() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(MysqlDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectMysqlToRds() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); + when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); + when(mockSuccessResultSet.getString(2)).thenReturn( + "Source distribution", "Source distribution", ""); + when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); + when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); + when(mockFailResultSet.next()).thenReturn(false); + final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); + } + + @Test + @Disabled + // TODO: fix me: need to split this test into two separate tests: + // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect + // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() + void testUpdateDialectMysqlToTaz() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); + when(mockSuccessResultSet.next()).thenReturn(true); + final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectMysqlToAurora() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); + when(mockSuccessResultSet.next()).thenReturn(true, false); + final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); + when(mockServicesContainer.getPluginService()).thenReturn(target); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectPgUnchanged() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(PgDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectPgToRds() throws SQLException { + when(mockStatement.executeQuery(any())) + .thenReturn(mockSuccessResultSet, mockFailResultSet, mockFailResultSet, mockSuccessResultSet); + when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); + when(mockSuccessResultSet.getBoolean("rds_tools")).thenReturn(true); + when(mockSuccessResultSet.getBoolean("aurora_stat_utils")).thenReturn(false); + when(mockSuccessResultSet.next()).thenReturn(true); + when(mockFailResultSet.next()).thenReturn(false); + final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(RdsPgDialect.class, target.dialect.getClass()); + } + + @Test + @Disabled + // TODO: fix me: need to split this test into two separate tests: + // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect + // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() + void testUpdateDialectPgToTaz() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); + when(mockSuccessResultSet.getBoolean(any())).thenReturn(false); + when(mockSuccessResultSet.next()).thenReturn(true); + final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(RdsMultiAzDbClusterPgDialect.class, target.dialect.getClass()); + } + + @Test + @Disabled + // TODO: fix me: need to split this test into two separate tests: + // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect + // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() + void testUpdateDialectPgToAurora() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockSuccessResultSet); + when(mockSuccessResultSet.next()).thenReturn(true); + when(mockSuccessResultSet.getBoolean(any())).thenReturn(true); + final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(AuroraPgDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectMariaUnchanged() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(MariaDbDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectMariaToMysqlRds() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'version_comment'")).thenReturn(mockSuccessResultSet); + when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'report_host'")).thenReturn(mockSuccessResultSet); + when(mockSuccessResultSet.getString(2)).thenReturn( + "Source distribution", "Source distribution", ""); + when(mockSuccessResultSet.next()).thenReturn(true, false, true, true); + when(mockSuccessResultSet.getMetaData()).thenReturn(mockResultSetMetaData); + when(mockFailResultSet.next()).thenReturn(false); + final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); + } + + @Test + @Disabled + // TODO: fix me: need to split this test into two separate tests: + // 1) test DialectManager.getDialect() to return RdsMultiAzDbClusterMysqlDialect + // 2) test PluginServiceImpl.updateDialect() with mocked DialectManager.getDialect() + void testUpdateDialectMariaToMysqlTaz() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet, mockSuccessResultSet); + final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(RdsMultiAzDbClusterMysqlDialect.class, target.dialect.getClass()); + } + + @Test + void testUpdateDialectMariaToMysqlAurora() throws SQLException { + when(mockStatement.executeQuery(any())).thenReturn(mockFailResultSet); + when(mockStatement.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")).thenReturn(mockSuccessResultSet); + when(mockSuccessResultSet.next()).thenReturn(true, false); + final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); + when(mockServicesContainer.getPluginService()).thenReturn(target); + target.setInitialConnectionHostSpec(mockHost); + target.updateDialect(mockConnection); + assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java index 05f700394..6e5844ccf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java @@ -1,247 +1,247 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import com.zaxxer.hikari.HikariConfig; -// import com.zaxxer.hikari.HikariDataSource; -// import com.zaxxer.hikari.HikariPoolMXBean; -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.ArrayList; -// import java.util.Collections; -// import java.util.HashSet; -// import java.util.List; -// import java.util.Properties; -// import java.util.Set; -// import java.util.concurrent.TimeUnit; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.targetdriverdialect.ConnectInfo; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.Pair; -// import software.amazon.jdbc.util.storage.SlidingExpirationCache; -// -// class HikariPooledConnectionProviderTest { -// @Mock Connection mockConnection; -// @Mock HikariDataSource mockDataSource; -// @Mock HostSpec mockHostSpec; -// @Mock HikariConfig mockConfig; -// @Mock Dialect mockDialect; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// @Mock HikariDataSource dsWithNoConnections; -// @Mock HikariDataSource dsWith1Connection; -// @Mock HikariDataSource dsWith2Connections; -// @Mock HikariPoolMXBean mxBeanWithNoConnections; -// @Mock HikariPoolMXBean mxBeanWith1Connection; -// @Mock HikariPoolMXBean mxBeanWith2Connections; -// private static final String LEAST_CONNECTIONS = "leastConnections"; -// private final int port = 5432; -// private final String user1 = "user1"; -// private final String user2 = "user2"; -// private final String password = "password"; -// private final String db = "mydb"; -// private final String writerUrlNoConnections = "writerWithNoConnections.XYZ.us-east-1.rds.amazonaws.com"; -// private final HostSpec writerHostNoConnections = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host(writerUrlNoConnections).port(port).role(HostRole.WRITER).build(); -// private final String readerUrl1Connection = "readerWith1connection.XYZ.us-east-1.rds.amazonaws.com"; -// private final HostSpec readerHost1Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host(readerUrl1Connection).port(port).role(HostRole.READER).build(); -// private final String readerUrl2Connection = "readerWith2connection.XYZ.us-east-1.rds.amazonaws.com"; -// private final HostSpec readerHost2Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host(readerUrl2Connection).port(port).role(HostRole.READER).build(); -// private final String protocol = "protocol://"; -// -// private final Properties defaultProps = getDefaultProps(); -// private final List testHosts = getTestHosts(); -// private HikariPooledConnectionProvider provider; -// -// private AutoCloseable closeable; -// -// private List getTestHosts() { -// List hosts = new ArrayList<>(); -// hosts.add(writerHostNoConnections); -// hosts.add(readerHost1Connection); -// hosts.add(readerHost2Connection); -// return hosts; -// } -// -// private Properties getDefaultProps() { -// Properties props = new Properties(); -// props.setProperty(PropertyDefinition.USER.name, user1); -// props.setProperty(PropertyDefinition.PASSWORD.name, password); -// props.setProperty(PropertyDefinition.DATABASE.name, db); -// return props; -// } -// -// @BeforeEach -// void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// when(mockDataSource.getConnection()).thenReturn(mockConnection); -// when(mockConnection.isValid(any(Integer.class))).thenReturn(true); -// when(dsWithNoConnections.getHikariPoolMXBean()).thenReturn(mxBeanWithNoConnections); -// when(mxBeanWithNoConnections.getActiveConnections()).thenReturn(0); -// when(dsWith1Connection.getHikariPoolMXBean()).thenReturn(mxBeanWith1Connection); -// when(mxBeanWith1Connection.getActiveConnections()).thenReturn(1); -// when(dsWith2Connections.getHikariPoolMXBean()).thenReturn(mxBeanWith2Connections); -// when(mxBeanWith2Connections.getActiveConnections()).thenReturn(2); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// if (provider != null) { -// provider.releaseResources(); -// } -// closeable.close(); -// } -// -// @Test -// void testConnectWithDefaultMapping() throws SQLException { -// when(mockHostSpec.getUrl()).thenReturn("url"); -// final Set expectedUrls = new HashSet<>(Collections.singletonList("url")); -// final Set expectedKeys = new HashSet<>( -// Collections.singletonList(Pair.create("url", user1))); -// -// provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); -// -// doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); -// doReturn(new ConnectInfo("url", new Properties())) -// .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); -// -// Properties props = new Properties(); -// props.setProperty(PropertyDefinition.USER.name, user1); -// props.setProperty(PropertyDefinition.PASSWORD.name, password); -// try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { -// assertEquals(mockConnection, conn); -// assertEquals(1, provider.getHostCount()); -// final Set hosts = provider.getHosts(); -// assertEquals(expectedUrls, hosts); -// final Set keys = provider.getKeys(); -// assertEquals(expectedKeys, keys); -// } -// } -// -// @Test -// void testConnectWithCustomMapping() throws SQLException { -// when(mockHostSpec.getUrl()).thenReturn("url"); -// final Set expectedKeys = new HashSet<>( -// Collections.singletonList(Pair.create("url", "url+someUniqueKey"))); -// -// provider = spy(new HikariPooledConnectionProvider( -// (hostSpec, properties) -> mockConfig, -// (hostSpec, properties) -> hostSpec.getUrl() + "+someUniqueKey")); -// -// doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); -// -// Properties props = new Properties(); -// props.setProperty(PropertyDefinition.USER.name, user1); -// props.setProperty(PropertyDefinition.PASSWORD.name, password); -// try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { -// assertEquals(mockConnection, conn); -// assertEquals(1, provider.getHostCount()); -// final Set keys = provider.getKeys(); -// assertEquals(expectedKeys, keys); -// } -// } -// -// @Test -// public void testAcceptsUrl() { -// final String clusterUrl = "my-database.cluster-XYZ.us-east-1.rds.amazonaws.com"; -// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); -// -// assertTrue( -// provider.acceptsUrl(protocol, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build(), -// defaultProps)); -// assertFalse( -// provider.acceptsUrl(protocol, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build(), defaultProps)); -// } -// -// @Test -// public void testRandomStrategy() throws SQLException { -// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); -// provider.setDatabasePools(getTestPoolMap()); -// -// HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, "random", defaultProps); -// assertTrue(readerUrl1Connection.equals(selectedHost.getHost()) -// || readerUrl2Connection.equals(selectedHost.getHost())); -// } -// -// @Test -// public void testLeastConnectionsStrategy() throws SQLException { -// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); -// provider.setDatabasePools(getTestPoolMap()); -// -// HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, LEAST_CONNECTIONS, defaultProps); -// // Other reader has 2 connections -// assertEquals(readerUrl1Connection, selectedHost.getHost()); -// } -// -// private SlidingExpirationCache getTestPoolMap() { -// SlidingExpirationCache map = new SlidingExpirationCache<>(); -// map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user1), -// (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); -// map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user2), -// (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); -// map.computeIfAbsent(Pair.create(readerHost1Connection.getUrl(), user1), -// (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); -// return map; -// } -// -// @Test -// public void testConfigurePool() throws SQLException { -// provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); -// final String expectedJdbcUrl = -// protocol + readerHost1Connection.getUrl() + db + "?database=" + db; -// doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) -// .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); -// -// provider.configurePool(mockConfig, protocol, readerHost1Connection, defaultProps, mockTargetDriverDialect); -// verify(mockConfig).setJdbcUrl(expectedJdbcUrl); -// verify(mockConfig).setUsername(user1); -// verify(mockConfig).setPassword(password); -// } -// -// @Test -// public void testConnectToDeletedInstance() throws SQLException { -// provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); -// -// doReturn(mockDataSource).when(provider) -// .createHikariDataSource(eq(protocol), eq(readerHost1Connection), eq(defaultProps), eq(mockTargetDriverDialect)); -// when(mockDataSource.getConnection()).thenThrow(SQLException.class); -// -// assertThrows(SQLException.class, -// () -> provider.connect(protocol, mockDialect, mockTargetDriverDialect, readerHost1Connection, defaultProps)); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import com.zaxxer.hikari.HikariPoolMXBean; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.targetdriverdialect.ConnectInfo; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.Pair; +import software.amazon.jdbc.util.storage.SlidingExpirationCache; + +class HikariPooledConnectionProviderTest { + @Mock Connection mockConnection; + @Mock HikariDataSource mockDataSource; + @Mock HostSpec mockHostSpec; + @Mock HikariConfig mockConfig; + @Mock Dialect mockDialect; + @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock HikariDataSource dsWithNoConnections; + @Mock HikariDataSource dsWith1Connection; + @Mock HikariDataSource dsWith2Connections; + @Mock HikariPoolMXBean mxBeanWithNoConnections; + @Mock HikariPoolMXBean mxBeanWith1Connection; + @Mock HikariPoolMXBean mxBeanWith2Connections; + private static final String LEAST_CONNECTIONS = "leastConnections"; + private final int port = 5432; + private final String user1 = "user1"; + private final String user2 = "user2"; + private final String password = "password"; + private final String db = "mydb"; + private final String writerUrlNoConnections = "writerWithNoConnections.XYZ.us-east-1.rds.amazonaws.com"; + private final HostSpec writerHostNoConnections = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(writerUrlNoConnections).port(port).role(HostRole.WRITER).build(); + private final String readerUrl1Connection = "readerWith1connection.XYZ.us-east-1.rds.amazonaws.com"; + private final HostSpec readerHost1Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(readerUrl1Connection).port(port).role(HostRole.READER).build(); + private final String readerUrl2Connection = "readerWith2connection.XYZ.us-east-1.rds.amazonaws.com"; + private final HostSpec readerHost2Connection = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(readerUrl2Connection).port(port).role(HostRole.READER).build(); + private final String protocol = "protocol://"; + + private final Properties defaultProps = getDefaultProps(); + private final List testHosts = getTestHosts(); + private HikariPooledConnectionProvider provider; + + private AutoCloseable closeable; + + private List getTestHosts() { + List hosts = new ArrayList<>(); + hosts.add(writerHostNoConnections); + hosts.add(readerHost1Connection); + hosts.add(readerHost2Connection); + return hosts; + } + + private Properties getDefaultProps() { + Properties props = new Properties(); + props.setProperty(PropertyDefinition.USER.name, user1); + props.setProperty(PropertyDefinition.PASSWORD.name, password); + props.setProperty(PropertyDefinition.DATABASE.name, db); + return props; + } + + @BeforeEach + void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + when(mockDataSource.getConnection()).thenReturn(mockConnection); + when(mockConnection.isValid(any(Integer.class))).thenReturn(true); + when(dsWithNoConnections.getHikariPoolMXBean()).thenReturn(mxBeanWithNoConnections); + when(mxBeanWithNoConnections.getActiveConnections()).thenReturn(0); + when(dsWith1Connection.getHikariPoolMXBean()).thenReturn(mxBeanWith1Connection); + when(mxBeanWith1Connection.getActiveConnections()).thenReturn(1); + when(dsWith2Connections.getHikariPoolMXBean()).thenReturn(mxBeanWith2Connections); + when(mxBeanWith2Connections.getActiveConnections()).thenReturn(2); + } + + @AfterEach + void tearDown() throws Exception { + if (provider != null) { + provider.releaseResources(); + } + closeable.close(); + } + + @Test + void testConnectWithDefaultMapping() throws SQLException { + when(mockHostSpec.getUrl()).thenReturn("url"); + final Set expectedUrls = new HashSet<>(Collections.singletonList("url")); + final Set expectedKeys = new HashSet<>( + Collections.singletonList(Pair.create("url", user1))); + + provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); + + doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); + doReturn(new ConnectInfo("url", new Properties())) + .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); + + Properties props = new Properties(); + props.setProperty(PropertyDefinition.USER.name, user1); + props.setProperty(PropertyDefinition.PASSWORD.name, password); + try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { + assertEquals(mockConnection, conn); + assertEquals(1, provider.getHostCount()); + final Set hosts = provider.getHosts(); + assertEquals(expectedUrls, hosts); + final Set keys = provider.getKeys(); + assertEquals(expectedKeys, keys); + } + } + + @Test + void testConnectWithCustomMapping() throws SQLException { + when(mockHostSpec.getUrl()).thenReturn("url"); + final Set expectedKeys = new HashSet<>( + Collections.singletonList(Pair.create("url", "url+someUniqueKey"))); + + provider = spy(new HikariPooledConnectionProvider( + (hostSpec, properties) -> mockConfig, + (hostSpec, properties) -> hostSpec.getUrl() + "+someUniqueKey")); + + doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); + + Properties props = new Properties(); + props.setProperty(PropertyDefinition.USER.name, user1); + props.setProperty(PropertyDefinition.PASSWORD.name, password); + try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { + assertEquals(mockConnection, conn); + assertEquals(1, provider.getHostCount()); + final Set keys = provider.getKeys(); + assertEquals(expectedKeys, keys); + } + } + + @Test + public void testAcceptsUrl() { + final String clusterUrl = "my-database.cluster-XYZ.us-east-1.rds.amazonaws.com"; + provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); + + assertTrue( + provider.acceptsUrl(protocol, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build(), + defaultProps)); + assertFalse( + provider.acceptsUrl(protocol, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build(), defaultProps)); + } + + @Test + public void testRandomStrategy() throws SQLException { + provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); + provider.setDatabasePools(getTestPoolMap()); + + HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, "random", defaultProps); + assertTrue(readerUrl1Connection.equals(selectedHost.getHost()) + || readerUrl2Connection.equals(selectedHost.getHost())); + } + + @Test + public void testLeastConnectionsStrategy() throws SQLException { + provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); + provider.setDatabasePools(getTestPoolMap()); + + HostSpec selectedHost = provider.getHostSpecByStrategy(testHosts, HostRole.READER, LEAST_CONNECTIONS, defaultProps); + // Other reader has 2 connections + assertEquals(readerUrl1Connection, selectedHost.getHost()); + } + + private SlidingExpirationCache getTestPoolMap() { + SlidingExpirationCache map = new SlidingExpirationCache<>(); + map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user1), + (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); + map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user2), + (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); + map.computeIfAbsent(Pair.create(readerHost1Connection.getUrl(), user1), + (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); + return map; + } + + @Test + public void testConfigurePool() throws SQLException { + provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); + final String expectedJdbcUrl = + protocol + readerHost1Connection.getUrl() + db + "?database=" + db; + doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) + .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); + + provider.configurePool(mockConfig, protocol, readerHost1Connection, defaultProps, mockTargetDriverDialect); + verify(mockConfig).setJdbcUrl(expectedJdbcUrl); + verify(mockConfig).setUsername(user1); + verify(mockConfig).setPassword(password); + } + + @Test + public void testConnectToDeletedInstance() throws SQLException { + provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); + + doReturn(mockDataSource).when(provider) + .createHikariDataSource(eq(protocol), eq(readerHost1Connection), eq(defaultProps), eq(mockTargetDriverDialect)); + when(mockDataSource.getConnection()).thenThrow(SQLException.class); + + assertThrows(SQLException.class, + () -> provider.connect(protocol, mockDialect, mockTargetDriverDialect, readerHost1Connection, defaultProps)); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index 4b234a741..07a22941f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -1,946 +1,946 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc; -// -// import static org.junit.jupiter.api.Assertions.assertArrayEquals; -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertNotEquals; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doNothing; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.HashSet; -// import java.util.List; -// import java.util.Map; -// import java.util.Properties; -// import java.util.Set; -// import java.util.stream.Stream; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.Arguments; -// import org.junit.jupiter.params.provider.MethodSource; -// import org.mockito.ArgumentCaptor; -// import org.mockito.Captor; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.dialect.AuroraPgDialect; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.dialect.DialectManager; -// import software.amazon.jdbc.dialect.MysqlDialect; -// import software.amazon.jdbc.exceptions.ExceptionManager; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.profile.ConfigurationProfile; -// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -// import software.amazon.jdbc.states.SessionStateService; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.events.EventPublisher; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.storage.TestStorageServiceImpl; -// -// public class PluginServiceImplTests { -// -// private static final Properties PROPERTIES = new Properties(); -// private static final String URL = "url"; -// private static final String DRIVER_PROTOCOL = "driverProtocol"; -// private StorageService storageService; -// private AutoCloseable closeable; -// -// @Mock FullServicesContainer servicesContainer; -// @Mock EventPublisher mockEventPublisher; -// @Mock ConnectionPluginManager pluginManager; -// @Mock Connection newConnection; -// @Mock Connection oldConnection; -// @Mock HostListProvider hostListProvider; -// @Mock DialectManager dialectManager; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// @Mock Statement statement; -// @Mock ResultSet resultSet; -// ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); -// @Mock SessionStateService sessionStateService; -// -// @Captor ArgumentCaptor> argumentChanges; -// @Captor ArgumentCaptor>> argumentChangesMap; -// @Captor ArgumentCaptor argumentSkipPlugin; -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// when(oldConnection.isClosed()).thenReturn(false); -// when(newConnection.createStatement()).thenReturn(statement); -// when(statement.executeQuery(any())).thenReturn(resultSet); -// when(servicesContainer.getConnectionPluginManager()).thenReturn(pluginManager); -// when(servicesContainer.getStorageService()).thenReturn(storageService); -// storageService = new TestStorageServiceImpl(mockEventPublisher); -// PluginServiceImpl.hostAvailabilityExpiringCache.clear(); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// storageService.clearAll(); -// PluginServiceImpl.hostAvailabilityExpiringCache.clear(); -// } -// -// @Test -// public void testOldConnectionNoSuggestion() throws SQLException { -// when(pluginManager.notifyConnectionChanged(any(), any())) -// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") -// .build(); -// -// target.setCurrentConnection(newConnection, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); -// -// assertNotEquals(oldConnection, target.currentConnection); -// assertEquals(newConnection, target.currentConnection); -// assertEquals("new-host", target.currentHostSpec.getHost()); -// verify(oldConnection, times(1)).close(); -// } -// -// @Test -// public void testOldConnectionDisposeSuggestion() throws SQLException { -// when(pluginManager.notifyConnectionChanged(any(), any())) -// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.DISPOSE)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") -// .build(); -// -// target.setCurrentConnection(newConnection, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); -// -// assertNotEquals(oldConnection, target.currentConnection); -// assertEquals(newConnection, target.currentConnection); -// assertEquals("new-host", target.currentHostSpec.getHost()); -// verify(oldConnection, times(1)).close(); -// } -// -// @Test -// public void testOldConnectionPreserveSuggestion() throws SQLException { -// when(pluginManager.notifyConnectionChanged(any(), any())) -// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.PRESERVE)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") -// .build(); -// -// target.setCurrentConnection(newConnection, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); -// -// assertNotEquals(oldConnection, target.currentConnection); -// assertEquals(newConnection, target.currentConnection); -// assertEquals("new-host", target.currentHostSpec.getHost()); -// verify(oldConnection, times(0)).close(); -// } -// -// @Test -// public void testOldConnectionMixedSuggestion() throws SQLException { -// when(pluginManager.notifyConnectionChanged(any(), any())) -// .thenReturn( -// EnumSet.of( -// OldConnectionSuggestedAction.NO_OPINION, -// OldConnectionSuggestedAction.PRESERVE, -// OldConnectionSuggestedAction.DISPOSE)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") -// .build(); -// -// target.setCurrentConnection(newConnection, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); -// -// assertNotEquals(oldConnection, target.currentConnection); -// assertEquals(newConnection, target.currentConnection); -// assertEquals("new-host", target.currentHostSpec.getHost()); -// verify(oldConnection, times(0)).close(); -// } -// -// @Test -// public void testChangesNewConnectionNewHostNewPortNewRoleNewAvailability() throws SQLException { -// when(pluginManager.notifyConnectionChanged( -// argumentChanges.capture(), argumentSkipPlugin.capture())) -// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).build(); -// -// target.setCurrentConnection( -// newConnection, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("new-host").port(2000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build()); -// -// assertNull(argumentSkipPlugin.getValue()); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); -// } -// -// @Test -// public void testChangesNewConnectionNewRoleNewAvailability() throws SQLException { -// when(pluginManager.notifyConnectionChanged( -// argumentChanges.capture(), argumentSkipPlugin.capture())) -// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build(); -// -// target.setCurrentConnection(newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE) -// .build()); -// -// assertNull(argumentSkipPlugin.getValue()); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); -// } -// -// @Test -// public void testChangesNewConnection() throws SQLException { -// when(pluginManager.notifyConnectionChanged( -// argumentChanges.capture(), argumentSkipPlugin.capture())) -// .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) -// .build(); -// -// target.setCurrentConnection( -// newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) -// .build()); -// -// assertNull(argumentSkipPlugin.getValue()); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); -// assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); -// assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); -// } -// -// @Test -// public void testChangesNoChanges() throws SQLException { -// when(pluginManager.notifyConnectionChanged(any(), any())).thenReturn( -// EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); -// -// PluginServiceImpl target = -// spy(new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.currentConnection = oldConnection; -// target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(); -// -// target.setCurrentConnection( -// oldConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) -// .build()); -// -// verify(pluginManager, times(0)).notifyConnectionChanged(any(), any()); -// } -// -// @Test -// public void testSetNodeListAdded() throws SQLException { -// -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// when(hostListProvider.refresh()).thenReturn(Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build())); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = new ArrayList<>(); -// target.hostListProvider = hostListProvider; -// -// target.refreshHostList(); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals("hostA", target.getAllHosts().get(0).getHost()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(1, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_ADDED)); -// } -// -// @Test -// public void testSetNodeListDeleted() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// when(hostListProvider.refresh()).thenReturn(Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build())); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); -// target.hostListProvider = hostListProvider; -// -// target.refreshHostList(); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals("hostB", target.getAllHosts().get(0).getHost()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(1, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_DELETED)); -// } -// -// @Test -// public void testSetNodeListChanged() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// when(hostListProvider.refresh()).thenReturn( -// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") -// .port(HostSpec.NO_PORT).role(HostRole.READER).build())); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); -// target.hostListProvider = hostListProvider; -// -// target.refreshHostList(); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals("hostA", target.getAllHosts().get(0).getHost()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(2, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); -// assertTrue(hostAChanges.contains(NodeChangeOptions.PROMOTED_TO_READER)); -// } -// -// @Test -// public void testSetNodeListNoChanges() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(any()); -// -// when(hostListProvider.refresh()).thenReturn( -// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build())); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); -// target.hostListProvider = hostListProvider; -// -// target.refreshHostList(); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals("hostA", target.getAllHosts().get(0).getHost()); -// verify(pluginManager, times(0)).notifyNodeListChanged(any()); -// } -// -// @Test -// public void testNodeAvailabilityNotChanged() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) -// .build()); -// -// Set aliases = new HashSet<>(); -// aliases.add("hostA"); -// target.setAvailability(aliases, HostAvailability.AVAILABLE); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); -// verify(pluginManager, never()).notifyNodeListChanged(any()); -// } -// -// @Test -// public void testNodeAvailabilityChanged_WentDown() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) -// .build()); -// -// Set aliases = new HashSet<>(); -// aliases.add("hostA"); -// target.setAvailability(aliases, HostAvailability.NOT_AVAILABLE); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals(HostAvailability.NOT_AVAILABLE, target.getAllHosts().get(0).getAvailability()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(2, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); -// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_DOWN)); -// } -// -// @Test -// public void testNodeAvailabilityChanged_WentUp() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.allHosts = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build()); -// -// Set aliases = new HashSet<>(); -// aliases.add("hostA"); -// target.setAvailability(aliases, HostAvailability.AVAILABLE); -// -// assertEquals(1, target.getAllHosts().size()); -// assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(2, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); -// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); -// } -// -// @Test -// public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build(); -// hostA.addAlias("ip-10-10-10-10"); -// hostA.addAlias("hostA.custom.domain.com"); -// final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build(); -// hostB.addAlias("ip-10-10-10-10"); -// hostB.addAlias("hostB.custom.domain.com"); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// -// target.allHosts = Arrays.asList(hostA, hostB); -// -// Set aliases = new HashSet<>(); -// aliases.add("hostA.custom.domain.com"); -// target.setAvailability(aliases, HostAvailability.AVAILABLE); -// -// assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); -// assertEquals(HostAvailability.NOT_AVAILABLE, hostB.getAvailability()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(2, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); -// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); -// } -// -// @Test -// public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQLException { -// doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); -// -// final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build();; -// hostA.addAlias("ip-10-10-10-10"); -// hostA.addAlias("hostA.custom.domain.com"); -// final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) -// .build(); -// hostB.addAlias("ip-10-10-10-10"); -// hostB.addAlias("hostB.custom.domain.com"); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// -// target.allHosts = Arrays.asList(hostA, hostB); -// -// Set aliases = new HashSet<>(); -// aliases.add("ip-10-10-10-10"); -// target.setAvailability(aliases, HostAvailability.AVAILABLE); -// -// assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); -// assertEquals(HostAvailability.AVAILABLE, hostB.getAvailability()); -// verify(pluginManager, times(1)).notifyNodeListChanged(any()); -// -// Map> notifiedChanges = argumentChangesMap.getValue(); -// assertTrue(notifiedChanges.containsKey("hostA/")); -// EnumSet hostAChanges = notifiedChanges.get("hostA/"); -// assertEquals(2, hostAChanges.size()); -// assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); -// assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); -// -// assertTrue(notifiedChanges.containsKey("hostB/")); -// EnumSet hostBChanges = notifiedChanges.get("hostB/"); -// assertEquals(2, hostBChanges.size()); -// assertTrue(hostBChanges.contains(NodeChangeOptions.NODE_CHANGED)); -// assertTrue(hostBChanges.contains(NodeChangeOptions.WENT_UP)); -// } -// -// @Test -// void testRefreshHostList_withCachedHostAvailability() throws SQLException { -// final List newHostSpecs = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// final List newHostSpecs2 = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// final List expectedHostSpecs = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// final List expectedHostSpecs2 = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// -// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, -// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); -// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, -// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); -// when(hostListProvider.refresh()).thenReturn(newHostSpecs); -// when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// when(target.getHostListProvider()).thenReturn(hostListProvider); -// -// assertNotEquals(expectedHostSpecs, newHostSpecs); -// target.refreshHostList(); -// assertEquals(expectedHostSpecs, newHostSpecs); -// -// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, -// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); -// target.refreshHostList(newConnection); -// assertEquals(expectedHostSpecs2, newHostSpecs); -// } -// -// @Test -// void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { -// final List newHostSpecs = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// final List expectedHostSpecs = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// final List expectedHostSpecs2 = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) -// .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() -// ); -// -// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, -// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); -// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, -// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); -// when(hostListProvider.forceRefresh()).thenReturn(newHostSpecs); -// when(hostListProvider.forceRefresh(newConnection)).thenReturn(newHostSpecs); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// when(target.getHostListProvider()).thenReturn(hostListProvider); -// -// assertNotEquals(expectedHostSpecs, newHostSpecs); -// target.forceRefreshHostList(); -// assertEquals(expectedHostSpecs, newHostSpecs); -// -// PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, -// PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); -// target.forceRefreshHostList(newConnection); -// assertEquals(expectedHostSpecs2, newHostSpecs); -// } -// -// @Test -// void testIdentifyConnectionWithNoAliases() throws SQLException { -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// when(target.getHostListProvider()).thenReturn(hostListProvider); -// -// when(target.getDialect()).thenReturn(new MysqlDialect()); -// assertNull(target.identifyConnection(newConnection)); -// } -// -// @Test -// void testIdentifyConnectionWithAliases() throws SQLException { -// final HostSpec expected = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("test") -// .build(); -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.hostListProvider = hostListProvider; -// when(target.getHostListProvider()).thenReturn(hostListProvider); -// when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); -// -// when(target.getDialect()).thenReturn(new AuroraPgDialect()); -// final HostSpec actual = target.identifyConnection(newConnection); -// verify(target, never()).getCurrentHostSpec(); -// verify(hostListProvider).identifyConnection(newConnection); -// assertEquals(expected, actual); -// } -// -// @Test -// void testFillAliasesNonEmptyAliases() throws SQLException { -// final HostSpec oneAlias = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo") -// .build(); -// oneAlias.addAlias(oneAlias.asAlias()); -// -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// -// assertEquals(1, oneAlias.getAliases().size()); -// target.fillAliases(newConnection, oneAlias); -// // Fill aliases should return directly and no additional aliases should be added. -// assertEquals(1, oneAlias.getAliases().size()); -// } -// -// @ParameterizedTest -// @MethodSource("fillAliasesDialects") -// void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { -// final HostSpec empty = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo").build(); -// PluginServiceImpl target = spy( -// new PluginServiceImpl( -// servicesContainer, -// new ExceptionManager(), -// PROPERTIES, -// URL, -// DRIVER_PROTOCOL, -// dialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// sessionStateService)); -// target.hostListProvider = hostListProvider; -// when(target.getDialect()).thenReturn(dialect); -// when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. -// when(resultSet.getString(eq(1))).thenReturn("ip"); -// if (dialect instanceof AuroraPgDialect) { -// when(hostListProvider.identifyConnection(eq(newConnection))) -// .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance").build()); -// } -// -// target.fillAliases(newConnection, empty); -// -// final String[] aliases = empty.getAliases().toArray(new String[] {}); -// assertArrayEquals(expectedInstanceAliases, aliases); -// } -// -// private static Stream fillAliasesDialects() { -// return Stream.of( -// Arguments.of(new AuroraPgDialect(), new String[]{"instance", "foo", "ip"}), -// Arguments.of(new MysqlDialect(), new String[]{"foo", "ip"}) -// ); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.dialect.AuroraPgDialect; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.dialect.DialectManager; +import software.amazon.jdbc.dialect.MysqlDialect; +import software.amazon.jdbc.exceptions.ExceptionManager; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.profile.ConfigurationProfile; +import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.events.EventPublisher; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.storage.TestStorageServiceImpl; + +public class PluginServiceImplTests { + + private static final Properties PROPERTIES = new Properties(); + private static final String URL = "url"; + private static final String DRIVER_PROTOCOL = "driverProtocol"; + private StorageService storageService; + private AutoCloseable closeable; + + @Mock FullServicesContainer servicesContainer; + @Mock EventPublisher mockEventPublisher; + @Mock ConnectionPluginManager pluginManager; + @Mock Connection newConnection; + @Mock Connection oldConnection; + @Mock HostListProvider hostListProvider; + @Mock DialectManager dialectManager; + @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock Statement statement; + @Mock ResultSet resultSet; + ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); + @Mock SessionStateService sessionStateService; + + @Captor ArgumentCaptor> argumentChanges; + @Captor ArgumentCaptor>> argumentChangesMap; + @Captor ArgumentCaptor argumentSkipPlugin; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + when(oldConnection.isClosed()).thenReturn(false); + when(newConnection.createStatement()).thenReturn(statement); + when(statement.executeQuery(any())).thenReturn(resultSet); + when(servicesContainer.getConnectionPluginManager()).thenReturn(pluginManager); + when(servicesContainer.getStorageService()).thenReturn(storageService); + storageService = new TestStorageServiceImpl(mockEventPublisher); + PluginServiceImpl.hostAvailabilityExpiringCache.clear(); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + storageService.clearAll(); + PluginServiceImpl.hostAvailabilityExpiringCache.clear(); + } + + @Test + public void testOldConnectionNoSuggestion() throws SQLException { + when(pluginManager.notifyConnectionChanged(any(), any())) + .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") + .build(); + + target.setCurrentConnection(newConnection, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); + + assertNotEquals(oldConnection, target.currentConnection); + assertEquals(newConnection, target.currentConnection); + assertEquals("new-host", target.currentHostSpec.getHost()); + verify(oldConnection, times(1)).close(); + } + + @Test + public void testOldConnectionDisposeSuggestion() throws SQLException { + when(pluginManager.notifyConnectionChanged(any(), any())) + .thenReturn(EnumSet.of(OldConnectionSuggestedAction.DISPOSE)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") + .build(); + + target.setCurrentConnection(newConnection, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); + + assertNotEquals(oldConnection, target.currentConnection); + assertEquals(newConnection, target.currentConnection); + assertEquals("new-host", target.currentHostSpec.getHost()); + verify(oldConnection, times(1)).close(); + } + + @Test + public void testOldConnectionPreserveSuggestion() throws SQLException { + when(pluginManager.notifyConnectionChanged(any(), any())) + .thenReturn(EnumSet.of(OldConnectionSuggestedAction.PRESERVE)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") + .build(); + + target.setCurrentConnection(newConnection, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); + + assertNotEquals(oldConnection, target.currentConnection); + assertEquals(newConnection, target.currentConnection); + assertEquals("new-host", target.currentHostSpec.getHost()); + verify(oldConnection, times(0)).close(); + } + + @Test + public void testOldConnectionMixedSuggestion() throws SQLException { + when(pluginManager.notifyConnectionChanged(any(), any())) + .thenReturn( + EnumSet.of( + OldConnectionSuggestedAction.NO_OPINION, + OldConnectionSuggestedAction.PRESERVE, + OldConnectionSuggestedAction.DISPOSE)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") + .build(); + + target.setCurrentConnection(newConnection, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("new-host").build()); + + assertNotEquals(oldConnection, target.currentConnection); + assertEquals(newConnection, target.currentConnection); + assertEquals("new-host", target.currentHostSpec.getHost()); + verify(oldConnection, times(0)).close(); + } + + @Test + public void testChangesNewConnectionNewHostNewPortNewRoleNewAvailability() throws SQLException { + when(pluginManager.notifyConnectionChanged( + argumentChanges.capture(), argumentSkipPlugin.capture())) + .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).build(); + + target.setCurrentConnection( + newConnection, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("new-host").port(2000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build()); + + assertNull(argumentSkipPlugin.getValue()); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); + } + + @Test + public void testChangesNewConnectionNewRoleNewAvailability() throws SQLException { + when(pluginManager.notifyConnectionChanged( + argumentChanges.capture(), argumentSkipPlugin.capture())) + .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build(); + + target.setCurrentConnection(newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE) + .build()); + + assertNull(argumentSkipPlugin.getValue()); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); + } + + @Test + public void testChangesNewConnection() throws SQLException { + when(pluginManager.notifyConnectionChanged( + argumentChanges.capture(), argumentSkipPlugin.capture())) + .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) + .build(); + + target.setCurrentConnection( + newConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) + .build()); + + assertNull(argumentSkipPlugin.getValue()); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_CHANGED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_ADDED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.NODE_DELETED)); + assertTrue(argumentChanges.getValue().contains(NodeChangeOptions.CONNECTION_OBJECT_CHANGED)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.HOSTNAME)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_READER)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.PROMOTED_TO_WRITER)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_DOWN)); + assertFalse(argumentChanges.getValue().contains(NodeChangeOptions.WENT_UP)); + } + + @Test + public void testChangesNoChanges() throws SQLException { + when(pluginManager.notifyConnectionChanged(any(), any())).thenReturn( + EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); + + PluginServiceImpl target = + spy(new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.currentConnection = oldConnection; + target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(); + + target.setCurrentConnection( + oldConnection, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE) + .build()); + + verify(pluginManager, times(0)).notifyConnectionChanged(any(), any()); + } + + @Test + public void testSetNodeListAdded() throws SQLException { + + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + when(hostListProvider.refresh()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build())); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = new ArrayList<>(); + target.hostListProvider = hostListProvider; + + target.refreshHostList(); + + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostA", target.getAllHosts().get(0).getHost()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(1, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_ADDED)); + } + + @Test + public void testSetNodeListDeleted() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + when(hostListProvider.refresh()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build())); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); + target.hostListProvider = hostListProvider; + + target.refreshHostList(); + + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostB", target.getAllHosts().get(0).getHost()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(1, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_DELETED)); + } + + @Test + public void testSetNodeListChanged() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + when(hostListProvider.refresh()).thenReturn( + Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") + .port(HostSpec.NO_PORT).role(HostRole.READER).build())); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); + target.hostListProvider = hostListProvider; + + target.refreshHostList(); + + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostA", target.getAllHosts().get(0).getHost()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(2, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); + assertTrue(hostAChanges.contains(NodeChangeOptions.PROMOTED_TO_READER)); + } + + @Test + public void testSetNodeListNoChanges() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(any()); + + when(hostListProvider.refresh()).thenReturn( + Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build())); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); + target.hostListProvider = hostListProvider; + + target.refreshHostList(); + + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostA", target.getAllHosts().get(0).getHost()); + verify(pluginManager, times(0)).notifyNodeListChanged(any()); + } + + @Test + public void testNodeAvailabilityNotChanged() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) + .build()); + + Set aliases = new HashSet<>(); + aliases.add("hostA"); + target.setAvailability(aliases, HostAvailability.AVAILABLE); + + assertEquals(1, target.getAllHosts().size()); + assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); + verify(pluginManager, never()).notifyNodeListChanged(any()); + } + + @Test + public void testNodeAvailabilityChanged_WentDown() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) + .build()); + + Set aliases = new HashSet<>(); + aliases.add("hostA"); + target.setAvailability(aliases, HostAvailability.NOT_AVAILABLE); + + assertEquals(1, target.getAllHosts().size()); + assertEquals(HostAvailability.NOT_AVAILABLE, target.getAllHosts().get(0).getAvailability()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(2, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); + assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_DOWN)); + } + + @Test + public void testNodeAvailabilityChanged_WentUp() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.allHosts = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build()); + + Set aliases = new HashSet<>(); + aliases.add("hostA"); + target.setAvailability(aliases, HostAvailability.AVAILABLE); + + assertEquals(1, target.getAllHosts().size()); + assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(2, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); + assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); + } + + @Test + public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build(); + hostA.addAlias("ip-10-10-10-10"); + hostA.addAlias("hostA.custom.domain.com"); + final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build(); + hostB.addAlias("ip-10-10-10-10"); + hostB.addAlias("hostB.custom.domain.com"); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + + target.allHosts = Arrays.asList(hostA, hostB); + + Set aliases = new HashSet<>(); + aliases.add("hostA.custom.domain.com"); + target.setAvailability(aliases, HostAvailability.AVAILABLE); + + assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); + assertEquals(HostAvailability.NOT_AVAILABLE, hostB.getAvailability()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(2, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); + assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); + } + + @Test + public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQLException { + doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); + + final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build();; + hostA.addAlias("ip-10-10-10-10"); + hostA.addAlias("hostA.custom.domain.com"); + final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("hostB").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) + .build(); + hostB.addAlias("ip-10-10-10-10"); + hostB.addAlias("hostB.custom.domain.com"); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + + target.allHosts = Arrays.asList(hostA, hostB); + + Set aliases = new HashSet<>(); + aliases.add("ip-10-10-10-10"); + target.setAvailability(aliases, HostAvailability.AVAILABLE); + + assertEquals(HostAvailability.AVAILABLE, hostA.getAvailability()); + assertEquals(HostAvailability.AVAILABLE, hostB.getAvailability()); + verify(pluginManager, times(1)).notifyNodeListChanged(any()); + + Map> notifiedChanges = argumentChangesMap.getValue(); + assertTrue(notifiedChanges.containsKey("hostA/")); + EnumSet hostAChanges = notifiedChanges.get("hostA/"); + assertEquals(2, hostAChanges.size()); + assertTrue(hostAChanges.contains(NodeChangeOptions.NODE_CHANGED)); + assertTrue(hostAChanges.contains(NodeChangeOptions.WENT_UP)); + + assertTrue(notifiedChanges.containsKey("hostB/")); + EnumSet hostBChanges = notifiedChanges.get("hostB/"); + assertEquals(2, hostBChanges.size()); + assertTrue(hostBChanges.contains(NodeChangeOptions.NODE_CHANGED)); + assertTrue(hostBChanges.contains(NodeChangeOptions.WENT_UP)); + } + + @Test + void testRefreshHostList_withCachedHostAvailability() throws SQLException { + final List newHostSpecs = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + final List newHostSpecs2 = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + final List expectedHostSpecs = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + final List expectedHostSpecs2 = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + + PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, + PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, + PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + when(hostListProvider.refresh()).thenReturn(newHostSpecs); + when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + when(target.getHostListProvider()).thenReturn(hostListProvider); + + assertNotEquals(expectedHostSpecs, newHostSpecs); + target.refreshHostList(); + assertEquals(expectedHostSpecs, newHostSpecs); + + PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, + PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + target.refreshHostList(newConnection); + assertEquals(expectedHostSpecs2, newHostSpecs); + } + + @Test + void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { + final List newHostSpecs = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + final List expectedHostSpecs = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + final List expectedHostSpecs2 = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostC").port(HostSpec.NO_PORT) + .role(HostRole.READER).availability(HostAvailability.AVAILABLE).build() + ); + + PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, + PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, + PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + when(hostListProvider.forceRefresh()).thenReturn(newHostSpecs); + when(hostListProvider.forceRefresh(newConnection)).thenReturn(newHostSpecs); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + when(target.getHostListProvider()).thenReturn(hostListProvider); + + assertNotEquals(expectedHostSpecs, newHostSpecs); + target.forceRefreshHostList(); + assertEquals(expectedHostSpecs, newHostSpecs); + + PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.AVAILABLE, + PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + target.forceRefreshHostList(newConnection); + assertEquals(expectedHostSpecs2, newHostSpecs); + } + + @Test + void testIdentifyConnectionWithNoAliases() throws SQLException { + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + when(target.getHostListProvider()).thenReturn(hostListProvider); + + when(target.getDialect()).thenReturn(new MysqlDialect()); + assertNull(target.identifyConnection(newConnection)); + } + + @Test + void testIdentifyConnectionWithAliases() throws SQLException { + final HostSpec expected = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("test") + .build(); + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.hostListProvider = hostListProvider; + when(target.getHostListProvider()).thenReturn(hostListProvider); + when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); + + when(target.getDialect()).thenReturn(new AuroraPgDialect()); + final HostSpec actual = target.identifyConnection(newConnection); + verify(target, never()).getCurrentHostSpec(); + verify(hostListProvider).identifyConnection(newConnection); + assertEquals(expected, actual); + } + + @Test + void testFillAliasesNonEmptyAliases() throws SQLException { + final HostSpec oneAlias = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo") + .build(); + oneAlias.addAlias(oneAlias.asAlias()); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + + assertEquals(1, oneAlias.getAliases().size()); + target.fillAliases(newConnection, oneAlias); + // Fill aliases should return directly and no additional aliases should be added. + assertEquals(1, oneAlias.getAliases().size()); + } + + @ParameterizedTest + @MethodSource("fillAliasesDialects") + void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { + final HostSpec empty = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo").build(); + PluginServiceImpl target = spy( + new PluginServiceImpl( + servicesContainer, + new ExceptionManager(), + PROPERTIES, + URL, + DRIVER_PROTOCOL, + dialectManager, + mockTargetDriverDialect, + configurationProfile, + sessionStateService)); + target.hostListProvider = hostListProvider; + when(target.getDialect()).thenReturn(dialect); + when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. + when(resultSet.getString(eq(1))).thenReturn("ip"); + if (dialect instanceof AuroraPgDialect) { + when(hostListProvider.identifyConnection(eq(newConnection))) + .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance").build()); + } + + target.fillAliases(newConnection, empty); + + final String[] aliases = empty.getAliases().toArray(new String[] {}); + assertArrayEquals(expectedInstanceAliases, aliases); + } + + private static Stream fillAliasesDialects() { + return Stream.of( + Arguments.of(new AuroraPgDialect(), new String[]{"instance", "foo", "ip"}), + Arguments.of(new MysqlDialect(), new String[]{"foo", "ip"}) + ); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java index f0abf31c2..797d151be 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java @@ -1,629 +1,629 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.hostlistprovider; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertNotEquals; -// import static org.junit.jupiter.api.Assertions.assertNotNull; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.atMostOnce; -// import static org.mockito.Mockito.doAnswer; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.mock; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import com.mysql.cj.exceptions.WrongArgumentException; -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.SQLSyntaxErrorException; -// import java.sql.Statement; -// import java.sql.Timestamp; -// import java.time.Instant; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.List; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.ArgumentCaptor; -// import org.mockito.Captor; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.events.EventPublisher; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.storage.TestStorageServiceImpl; -// -// class RdsHostListProviderTest { -// private StorageService storageService; -// private RdsHostListProvider rdsHostListProvider; -// -// @Mock private Connection mockConnection; -// @Mock private Statement mockStatement; -// @Mock private ResultSet mockResultSet; -// @Mock private FullServicesContainer mockServicesContainer; -// @Mock private PluginService mockPluginService; -// @Mock private HostListProviderService mockHostListProviderService; -// @Mock private EventPublisher mockEventPublisher; -// @Mock Dialect mockTopologyAwareDialect; -// @Captor private ArgumentCaptor queryCaptor; -// -// private AutoCloseable closeable; -// private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("foo").port(1234).build(); -// private final List hosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// storageService = new TestStorageServiceImpl(mockEventPublisher); -// when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); -// when(mockServicesContainer.getStorageService()).thenReturn(storageService); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); -// when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); -// when(mockConnection.createStatement()).thenReturn(mockStatement); -// when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); -// when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); -// when(mockHostListProviderService.getHostSpecBuilder()) -// .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); -// when(mockHostListProviderService.getCurrentConnection()).thenReturn(mockConnection); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// RdsHostListProvider.clearAll(); -// storageService.clearAll(); -// closeable.close(); -// } -// -// private RdsHostListProvider getRdsHostListProvider(String originalUrl) throws SQLException { -// RdsHostListProvider provider = new RdsHostListProvider( -// new Properties(), -// originalUrl, -// mockServicesContainer, -// "foo", "bar", "baz"); -// provider.init(); -// return provider; -// } -// -// @Test -// void testGetTopology_returnCachedTopology() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("protocol://url/")); -// -// final List expected = hosts; -// storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); -// assertEquals(expected, result.hosts); -// assertEquals(2, result.hosts.size()); -// verify(rdsHostListProvider, never()).queryForTopology(mockConnection); -// } -// -// @Test -// void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.isInitialized = true; -// -// storageService.set(rdsHostListProvider.clusterId, new Topology(hosts)); -// -// final List newHosts = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); -// doReturn(newHosts).when(rdsHostListProvider).queryForTopology(mockConnection); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); -// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertEquals(1, result.hosts.size()); -// assertEquals(newHosts, result.hosts); -// } -// -// @Test -// void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.clusterId = "cluster-id"; -// rdsHostListProvider.isInitialized = true; -// -// final List expected = hosts; -// storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); -// -// doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); -// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertEquals(2, result.hosts.size()); -// assertEquals(expected, result.hosts); -// } -// -// @Test -// void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.clear(); -// -// doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); -// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertNotNull(result.hosts); -// assertEquals( -// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), -// result.hosts); -// } -// -// @Test -// void testQueryForTopology_withDifferentDriverProtocol() throws SQLException { -// final List expectedMySQL = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("mysql").port(HostSpec.NO_PORT) -// .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); -// final List expectedPostgres = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("postgresql").port(HostSpec.NO_PORT) -// .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); -// when(mockResultSet.next()).thenReturn(true, false); -// when(mockResultSet.getBoolean(eq(2))).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("mysql"); -// -// -// rdsHostListProvider = getRdsHostListProvider("mysql://url/"); -// -// List hosts = rdsHostListProvider.queryForTopology(mockConnection); -// assertEquals(expectedMySQL, hosts); -// -// when(mockResultSet.next()).thenReturn(true, false); -// when(mockResultSet.getString(eq(1))).thenReturn("postgresql"); -// -// rdsHostListProvider = getRdsHostListProvider("postgresql://url/"); -// hosts = rdsHostListProvider.queryForTopology(mockConnection); -// assertEquals(expectedPostgres, hosts); -// } -// -// @Test -// void testQueryForTopology_queryResultsInException() throws SQLException { -// rdsHostListProvider = getRdsHostListProvider("protocol://url/"); -// when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); -// -// assertThrows( -// SQLException.class, -// () -> rdsHostListProvider.queryForTopology(mockConnection)); -// } -// -// @Test -// void testGetCachedTopology_returnStoredTopology() throws SQLException { -// rdsHostListProvider = getRdsHostListProvider("jdbc:someprotocol://url"); -// -// final List expected = hosts; -// storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); -// -// final List result = rdsHostListProvider.getStoredTopology(); -// assertEquals(expected, result); -// } -// -// @Test -// void testTopologyCache_NoSuggestedClusterId() throws SQLException { -// RdsHostListProvider.clearAll(); -// -// RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.domain.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); -// -// doReturn(topologyClusterA) -// .when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// final List topologyProvider1 = provider1.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-b.domain.com/")); -// provider2.init(); -// assertNull(provider2.getStoredTopology()); -// -// final List topologyClusterB = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); -// doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); -// -// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); -// assertEquals(topologyClusterB, topologyProvider2); -// -// assertEquals(2, storageService.size(Topology.class)); -// } -// -// @Test -// void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { -// RdsHostListProvider.clearAll(); -// -// RdsHostListProvider provider1 = -// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build()); -// -// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// final List topologyProvider1 = provider1.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// RdsHostListProvider provider2 = -// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider2.init(); -// -// assertEquals(provider1.clusterId, provider2.clusterId); -// assertTrue(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// -// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider2); -// -// assertEquals(1, storageService.size(Topology.class)); -// } -// -// @Test -// void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { -// RdsHostListProvider.clearAll(); -// -// RdsHostListProvider provider1 = -// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build()); -// -// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// final List topologyProvider1 = provider1.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// RdsHostListProvider provider2 = -// Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); -// provider2.init(); -// -// assertEquals(provider1.clusterId, provider2.clusterId); -// assertTrue(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// -// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider2); -// -// assertEquals(1, storageService.size(Topology.class)); -// } -// -// @Test -// void testTopologyCache_AcceptSuggestion() throws SQLException { -// RdsHostListProvider.clearAll(); -// -// RdsHostListProvider provider1 = -// Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build()); -// -// doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// List topologyProvider1 = provider1.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// // RdsHostListProvider.logCache(); -// -// RdsHostListProvider provider2 = -// Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider2.init(); -// -// doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); -// -// final List topologyProvider2 = provider2.refresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider2); -// -// assertNotEquals(provider1.clusterId, provider2.clusterId); -// assertFalse(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// assertEquals(2, storageService.size(Topology.class)); -// assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", -// RdsHostListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); -// -// // RdsHostListProvider.logCache(); -// -// topologyProvider1 = provider1.forceRefresh(mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// assertEquals(provider1.clusterId, provider2.clusterId); -// assertTrue(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// -// // RdsHostListProvider.logCache(); -// } -// -// @Test -// void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// -// when(mockResultSet.next()).thenReturn(false); -// assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); -// -// when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); -// assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); -// } -// -// @Test -// void testIdentifyConnectionNullTopology() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("?.pattern").build(); -// -// when(mockResultSet.next()).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); -// doReturn(null).when(rdsHostListProvider).refresh(mockConnection); -// doReturn(null).when(rdsHostListProvider).forceRefresh(mockConnection); -// -// assertNull(rdsHostListProvider.identifyConnection(mockConnection)); -// } -// -// @Test -// void testIdentifyConnectionHostNotInTopology() throws SQLException { -// final List cachedTopology = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build()); -// -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// when(mockResultSet.next()).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); -// doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); -// doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); -// -// assertNull(rdsHostListProvider.identifyConnection(mockConnection)); -// } -// -// @Test -// void testIdentifyConnectionHostInTopology() throws SQLException { -// final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(); -// expectedHost.setHostId("instance-a-1"); -// final List cachedTopology = Collections.singletonList(expectedHost); -// -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// when(mockResultSet.next()).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); -// doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); -// doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); -// -// final HostSpec actual = rdsHostListProvider.identifyConnection(mockConnection); -// assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); -// assertEquals("instance-a-1", actual.getHostId()); -// } -// -// @Test -// void testGetTopology_StaleRecord() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.isInitialized = true; -// -// final String hostName1 = "hostName1"; -// final String hostName2 = "hostName2"; -// final Double cpuUtilization = 11.1D; -// final Double nodeLag = 0.123D; -// final Timestamp firstTimestamp = Timestamp.from(Instant.now()); -// final Timestamp secondTimestamp = new Timestamp(firstTimestamp.getTime() + 100); -// when(mockResultSet.next()).thenReturn(true, true, false); -// when(mockResultSet.getString(1)).thenReturn(hostName1).thenReturn(hostName2); -// when(mockResultSet.getBoolean(2)).thenReturn(true).thenReturn(true); -// when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization).thenReturn(cpuUtilization); -// when(mockResultSet.getDouble(4)).thenReturn(nodeLag).thenReturn(nodeLag); -// when(mockResultSet.getTimestamp(5)).thenReturn(firstTimestamp).thenReturn(secondTimestamp); -// long weight = Math.round(nodeLag) * 100L + Math.round(cpuUtilization); -// final HostSpec expectedWriter = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host(hostName2) -// .port(-1) -// .role(HostRole.WRITER) -// .availability(HostAvailability.AVAILABLE) -// .weight(weight) -// .lastUpdateTime(secondTimestamp) -// .build(); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); -// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertEquals(1, result.hosts.size()); -// assertEquals(expectedWriter, result.hosts.get(0)); -// } -// -// @Test -// void testGetTopology_InvalidLastUpdatedTimestamp() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.isInitialized = true; -// -// final String hostName = "hostName"; -// final Double cpuUtilization = 11.1D; -// final Double nodeLag = 0.123D; -// when(mockResultSet.next()).thenReturn(true, false); -// when(mockResultSet.getString(1)).thenReturn(hostName); -// when(mockResultSet.getBoolean(2)).thenReturn(true); -// when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization); -// when(mockResultSet.getDouble(4)).thenReturn(nodeLag); -// when(mockResultSet.getTimestamp(5)).thenThrow(WrongArgumentException.class); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); -// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// -// final String expectedLastUpdatedTimeStampRounded = Timestamp.from(Instant.now()).toString().substring(0, 16); -// assertEquals(1, result.hosts.size()); -// assertEquals( -// expectedLastUpdatedTimeStampRounded, -// result.hosts.get(0).getLastUpdateTime().toString().substring(0, 16)); -// } -// -// @Test -// void testGetTopology_returnsLatestWriter() throws SQLException { -// rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); -// rdsHostListProvider.isInitialized = true; -// -// HostSpec expectedWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("expectedWriterHost") -// .role(HostRole.WRITER) -// .lastUpdateTime(Timestamp.valueOf("3000-01-01 00:00:00")) -// .build(); -// -// HostSpec unexpectedWriterHost0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("unexpectedWriterHost0") -// .role(HostRole.WRITER) -// .lastUpdateTime(Timestamp.valueOf("1000-01-01 00:00:00")) -// .build(); -// -// HostSpec unexpectedWriterHost1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("unexpectedWriterHost1") -// .role(HostRole.WRITER) -// .lastUpdateTime(Timestamp.valueOf("2000-01-01 00:00:00")) -// .build(); -// -// HostSpec unexpectedWriterHostWithNullLastUpdateTime0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("unexpectedWriterHostWithNullLastUpdateTime0") -// .role(HostRole.WRITER) -// .lastUpdateTime(null) -// .build(); -// -// HostSpec unexpectedWriterHostWithNullLastUpdateTime1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("unexpectedWriterHostWithNullLastUpdateTime1") -// .role(HostRole.WRITER) -// .lastUpdateTime(null) -// .build(); -// -// when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); -// -// when(mockResultSet.getString(1)).thenReturn( -// unexpectedWriterHostWithNullLastUpdateTime0.getHost(), -// unexpectedWriterHost0.getHost(), -// expectedWriterHost.getHost(), -// unexpectedWriterHost1.getHost(), -// unexpectedWriterHostWithNullLastUpdateTime1.getHost()); -// when(mockResultSet.getBoolean(2)).thenReturn(true, true, true, true, true); -// when(mockResultSet.getFloat(3)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); -// when(mockResultSet.getFloat(4)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); -// when(mockResultSet.getTimestamp(5)).thenReturn( -// unexpectedWriterHostWithNullLastUpdateTime0.getLastUpdateTime(), -// unexpectedWriterHost0.getLastUpdateTime(), -// expectedWriterHost.getLastUpdateTime(), -// unexpectedWriterHost1.getLastUpdateTime(), -// unexpectedWriterHostWithNullLastUpdateTime1.getLastUpdateTime() -// ); -// -// final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); -// verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// -// assertEquals(expectedWriterHost.getHost(), result.hosts.get(0).getHost()); -// } -// -// @Test -// void testClusterUrlUsedAsDefaultClusterId() throws SQLException { -// String readerClusterUrl = "mycluster.cluster-ro-XYZ.us-east-1.rds.amazonaws.com"; -// String expectedClusterId = "mycluster.cluster-XYZ.us-east-1.rds.amazonaws.com:1234"; -// String connectionString = "jdbc:someprotocol://" + readerClusterUrl + ":1234/test"; -// RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider(connectionString)); -// assertEquals(expectedClusterId, provider1.getClusterId()); -// -// List mockTopology = -// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build()); -// doReturn(mockTopology).when(provider1).queryForTopology(any(Connection.class)); -// provider1.refresh(); -// assertEquals(mockTopology, provider1.getStoredTopology()); -// verify(provider1, times(1)).queryForTopology(mockConnection); -// -// RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider(connectionString)); -// assertEquals(expectedClusterId, provider2.getClusterId()); -// assertEquals(mockTopology, provider2.getStoredTopology()); -// verify(provider2, never()).queryForTopology(mockConnection); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.hostlistprovider; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atMostOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.mysql.cj.exceptions.WrongArgumentException; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLSyntaxErrorException; +import java.sql.Statement; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.events.EventPublisher; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.storage.TestStorageServiceImpl; + +class RdsHostListProviderTest { + private StorageService storageService; + private RdsHostListProvider rdsHostListProvider; + + @Mock private Connection mockConnection; + @Mock private Statement mockStatement; + @Mock private ResultSet mockResultSet; + @Mock private FullServicesContainer mockServicesContainer; + @Mock private PluginService mockPluginService; + @Mock private HostListProviderService mockHostListProviderService; + @Mock private EventPublisher mockEventPublisher; + @Mock Dialect mockTopologyAwareDialect; + @Captor private ArgumentCaptor queryCaptor; + + private AutoCloseable closeable; + private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("foo").port(1234).build(); + private final List hosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + storageService = new TestStorageServiceImpl(mockEventPublisher); + when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); + when(mockServicesContainer.getStorageService()).thenReturn(storageService); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); + when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); + when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); + when(mockHostListProviderService.getHostSpecBuilder()) + .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); + when(mockHostListProviderService.getCurrentConnection()).thenReturn(mockConnection); + } + + @AfterEach + void tearDown() throws Exception { + RdsHostListProvider.clearAll(); + storageService.clearAll(); + closeable.close(); + } + + private RdsHostListProvider getRdsHostListProvider(String originalUrl) throws SQLException { + RdsHostListProvider provider = new RdsHostListProvider( + new Properties(), + originalUrl, + mockServicesContainer, + "foo", "bar", "baz"); + provider.init(); + return provider; + } + + @Test + void testGetTopology_returnCachedTopology() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("protocol://url/")); + + final List expected = hosts; + storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); + assertEquals(expected, result.hosts); + assertEquals(2, result.hosts.size()); + verify(rdsHostListProvider, never()).queryForTopology(mockConnection); + } + + @Test + void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.isInitialized = true; + + storageService.set(rdsHostListProvider.clusterId, new Topology(hosts)); + + final List newHosts = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); + doReturn(newHosts).when(rdsHostListProvider).queryForTopology(mockConnection); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); + verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertEquals(1, result.hosts.size()); + assertEquals(newHosts, result.hosts); + } + + @Test + void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.clusterId = "cluster-id"; + rdsHostListProvider.isInitialized = true; + + final List expected = hosts; + storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); + + doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, false); + verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertEquals(2, result.hosts.size()); + assertEquals(expected, result.hosts); + } + + @Test + void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.clear(); + + doReturn(new ArrayList<>()).when(rdsHostListProvider).queryForTopology(mockConnection); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); + verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertNotNull(result.hosts); + assertEquals( + Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), + result.hosts); + } + + @Test + void testQueryForTopology_withDifferentDriverProtocol() throws SQLException { + final List expectedMySQL = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("mysql").port(HostSpec.NO_PORT) + .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); + final List expectedPostgres = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("postgresql").port(HostSpec.NO_PORT) + .role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).weight(0).build()); + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getBoolean(eq(2))).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("mysql"); + + + rdsHostListProvider = getRdsHostListProvider("mysql://url/"); + + List hosts = rdsHostListProvider.queryForTopology(mockConnection); + assertEquals(expectedMySQL, hosts); + + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString(eq(1))).thenReturn("postgresql"); + + rdsHostListProvider = getRdsHostListProvider("postgresql://url/"); + hosts = rdsHostListProvider.queryForTopology(mockConnection); + assertEquals(expectedPostgres, hosts); + } + + @Test + void testQueryForTopology_queryResultsInException() throws SQLException { + rdsHostListProvider = getRdsHostListProvider("protocol://url/"); + when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); + + assertThrows( + SQLException.class, + () -> rdsHostListProvider.queryForTopology(mockConnection)); + } + + @Test + void testGetCachedTopology_returnStoredTopology() throws SQLException { + rdsHostListProvider = getRdsHostListProvider("jdbc:someprotocol://url"); + + final List expected = hosts; + storageService.set(rdsHostListProvider.clusterId, new Topology(expected)); + + final List result = rdsHostListProvider.getStoredTopology(); + assertEquals(expected, result); + } + + @Test + void testTopologyCache_NoSuggestedClusterId() throws SQLException { + RdsHostListProvider.clearAll(); + + RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.domain.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); + + doReturn(topologyClusterA) + .when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + final List topologyProvider1 = provider1.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-b.domain.com/")); + provider2.init(); + assertNull(provider2.getStoredTopology()); + + final List topologyClusterB = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); + doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); + + final List topologyProvider2 = provider2.refresh(mock(Connection.class)); + assertEquals(topologyClusterB, topologyProvider2); + + assertEquals(2, storageService.size(Topology.class)); + } + + @Test + void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { + RdsHostListProvider.clearAll(); + + RdsHostListProvider provider1 = + Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build()); + + doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + final List topologyProvider1 = provider1.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + RdsHostListProvider provider2 = + Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider2.init(); + + assertEquals(provider1.clusterId, provider2.clusterId); + assertTrue(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + + final List topologyProvider2 = provider2.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider2); + + assertEquals(1, storageService.size(Topology.class)); + } + + @Test + void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { + RdsHostListProvider.clearAll(); + + RdsHostListProvider provider1 = + Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build()); + + doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + final List topologyProvider1 = provider1.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + RdsHostListProvider provider2 = + Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); + provider2.init(); + + assertEquals(provider1.clusterId, provider2.clusterId); + assertTrue(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + + final List topologyProvider2 = provider2.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider2); + + assertEquals(1, storageService.size(Topology.class)); + } + + @Test + void testTopologyCache_AcceptSuggestion() throws SQLException { + RdsHostListProvider.clearAll(); + + RdsHostListProvider provider1 = + Mockito.spy(getRdsHostListProvider("jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build()); + + doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + List topologyProvider1 = provider1.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + // RdsHostListProvider.logCache(); + + RdsHostListProvider provider2 = + Mockito.spy(getRdsHostListProvider("jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider2.init(); + + doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); + + final List topologyProvider2 = provider2.refresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider2); + + assertNotEquals(provider1.clusterId, provider2.clusterId); + assertFalse(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + assertEquals(2, storageService.size(Topology.class)); + assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", + RdsHostListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); + + // RdsHostListProvider.logCache(); + + topologyProvider1 = provider1.forceRefresh(mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + assertEquals(provider1.clusterId, provider2.clusterId); + assertTrue(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + + // RdsHostListProvider.logCache(); + } + + @Test + void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + + when(mockResultSet.next()).thenReturn(false); + assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); + + when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); + assertThrows(SQLException.class, () -> rdsHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionNullTopology() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("?.pattern").build(); + + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); + doReturn(null).when(rdsHostListProvider).refresh(mockConnection); + doReturn(null).when(rdsHostListProvider).forceRefresh(mockConnection); + + assertNull(rdsHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionHostNotInTopology() throws SQLException { + final List cachedTopology = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build()); + + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); + doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); + doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); + + assertNull(rdsHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionHostInTopology() throws SQLException { + final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(); + expectedHost.setHostId("instance-a-1"); + final List cachedTopology = Collections.singletonList(expectedHost); + + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); + doReturn(cachedTopology).when(rdsHostListProvider).refresh(mockConnection); + doReturn(cachedTopology).when(rdsHostListProvider).forceRefresh(mockConnection); + + final HostSpec actual = rdsHostListProvider.identifyConnection(mockConnection); + assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); + assertEquals("instance-a-1", actual.getHostId()); + } + + @Test + void testGetTopology_StaleRecord() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.isInitialized = true; + + final String hostName1 = "hostName1"; + final String hostName2 = "hostName2"; + final Double cpuUtilization = 11.1D; + final Double nodeLag = 0.123D; + final Timestamp firstTimestamp = Timestamp.from(Instant.now()); + final Timestamp secondTimestamp = new Timestamp(firstTimestamp.getTime() + 100); + when(mockResultSet.next()).thenReturn(true, true, false); + when(mockResultSet.getString(1)).thenReturn(hostName1).thenReturn(hostName2); + when(mockResultSet.getBoolean(2)).thenReturn(true).thenReturn(true); + when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization).thenReturn(cpuUtilization); + when(mockResultSet.getDouble(4)).thenReturn(nodeLag).thenReturn(nodeLag); + when(mockResultSet.getTimestamp(5)).thenReturn(firstTimestamp).thenReturn(secondTimestamp); + long weight = Math.round(nodeLag) * 100L + Math.round(cpuUtilization); + final HostSpec expectedWriter = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(hostName2) + .port(-1) + .role(HostRole.WRITER) + .availability(HostAvailability.AVAILABLE) + .weight(weight) + .lastUpdateTime(secondTimestamp) + .build(); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); + verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertEquals(1, result.hosts.size()); + assertEquals(expectedWriter, result.hosts.get(0)); + } + + @Test + void testGetTopology_InvalidLastUpdatedTimestamp() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.isInitialized = true; + + final String hostName = "hostName"; + final Double cpuUtilization = 11.1D; + final Double nodeLag = 0.123D; + when(mockResultSet.next()).thenReturn(true, false); + when(mockResultSet.getString(1)).thenReturn(hostName); + when(mockResultSet.getBoolean(2)).thenReturn(true); + when(mockResultSet.getDouble(3)).thenReturn(cpuUtilization); + when(mockResultSet.getDouble(4)).thenReturn(nodeLag); + when(mockResultSet.getTimestamp(5)).thenThrow(WrongArgumentException.class); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); + verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); + + final String expectedLastUpdatedTimeStampRounded = Timestamp.from(Instant.now()).toString().substring(0, 16); + assertEquals(1, result.hosts.size()); + assertEquals( + expectedLastUpdatedTimeStampRounded, + result.hosts.get(0).getLastUpdateTime().toString().substring(0, 16)); + } + + @Test + void testGetTopology_returnsLatestWriter() throws SQLException { + rdsHostListProvider = Mockito.spy(getRdsHostListProvider("jdbc:someprotocol://url")); + rdsHostListProvider.isInitialized = true; + + HostSpec expectedWriterHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("expectedWriterHost") + .role(HostRole.WRITER) + .lastUpdateTime(Timestamp.valueOf("3000-01-01 00:00:00")) + .build(); + + HostSpec unexpectedWriterHost0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("unexpectedWriterHost0") + .role(HostRole.WRITER) + .lastUpdateTime(Timestamp.valueOf("1000-01-01 00:00:00")) + .build(); + + HostSpec unexpectedWriterHost1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("unexpectedWriterHost1") + .role(HostRole.WRITER) + .lastUpdateTime(Timestamp.valueOf("2000-01-01 00:00:00")) + .build(); + + HostSpec unexpectedWriterHostWithNullLastUpdateTime0 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("unexpectedWriterHostWithNullLastUpdateTime0") + .role(HostRole.WRITER) + .lastUpdateTime(null) + .build(); + + HostSpec unexpectedWriterHostWithNullLastUpdateTime1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("unexpectedWriterHostWithNullLastUpdateTime1") + .role(HostRole.WRITER) + .lastUpdateTime(null) + .build(); + + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + when(mockResultSet.getString(1)).thenReturn( + unexpectedWriterHostWithNullLastUpdateTime0.getHost(), + unexpectedWriterHost0.getHost(), + expectedWriterHost.getHost(), + unexpectedWriterHost1.getHost(), + unexpectedWriterHostWithNullLastUpdateTime1.getHost()); + when(mockResultSet.getBoolean(2)).thenReturn(true, true, true, true, true); + when(mockResultSet.getFloat(3)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); + when(mockResultSet.getFloat(4)).thenReturn((float) 0, (float) 0, (float) 0, (float) 0, (float) 0); + when(mockResultSet.getTimestamp(5)).thenReturn( + unexpectedWriterHostWithNullLastUpdateTime0.getLastUpdateTime(), + unexpectedWriterHost0.getLastUpdateTime(), + expectedWriterHost.getLastUpdateTime(), + unexpectedWriterHost1.getLastUpdateTime(), + unexpectedWriterHostWithNullLastUpdateTime1.getLastUpdateTime() + ); + + final FetchTopologyResult result = rdsHostListProvider.getTopology(mockConnection, true); + verify(rdsHostListProvider, atMostOnce()).queryForTopology(mockConnection); + + assertEquals(expectedWriterHost.getHost(), result.hosts.get(0).getHost()); + } + + @Test + void testClusterUrlUsedAsDefaultClusterId() throws SQLException { + String readerClusterUrl = "mycluster.cluster-ro-XYZ.us-east-1.rds.amazonaws.com"; + String expectedClusterId = "mycluster.cluster-XYZ.us-east-1.rds.amazonaws.com:1234"; + String connectionString = "jdbc:someprotocol://" + readerClusterUrl + ":1234/test"; + RdsHostListProvider provider1 = Mockito.spy(getRdsHostListProvider(connectionString)); + assertEquals(expectedClusterId, provider1.getClusterId()); + + List mockTopology = + Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build()); + doReturn(mockTopology).when(provider1).queryForTopology(any(Connection.class)); + provider1.refresh(); + assertEquals(mockTopology, provider1.getStoredTopology()); + verify(provider1, times(1)).queryForTopology(mockConnection); + + RdsHostListProvider provider2 = Mockito.spy(getRdsHostListProvider(connectionString)); + assertEquals(expectedClusterId, provider2.getClusterId()); + assertEquals(mockTopology, provider2.getStoredTopology()); + verify(provider2, never()).queryForTopology(mockConnection); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java index db5e10c62..df6d6ee50 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java @@ -1,470 +1,470 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.hostlistprovider; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertFalse; -// import static org.junit.jupiter.api.Assertions.assertNotEquals; -// import static org.junit.jupiter.api.Assertions.assertNotNull; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.junit.jupiter.api.Assertions.assertTrue; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.atMostOnce; -// import static org.mockito.Mockito.doAnswer; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.SQLSyntaxErrorException; -// import java.sql.Statement; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.List; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.ArgumentCaptor; -// import org.mockito.Captor; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.events.EventPublisher; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.storage.TestStorageServiceImpl; -// -// class RdsMultiAzDbClusterListProviderTest { -// private StorageService storageService; -// private RdsMultiAzDbClusterListProvider rdsMazDbClusterHostListProvider; -// -// @Mock private Connection mockConnection; -// @Mock private Statement mockStatement; -// @Mock private ResultSet mockResultSet; -// @Mock private FullServicesContainer mockServicesContainer; -// @Mock private PluginService mockPluginService; -// @Mock private HostListProviderService mockHostListProviderService; -// @Mock private EventPublisher mockEventPublisher; -// @Mock Dialect mockTopologyAwareDialect; -// @Captor private ArgumentCaptor queryCaptor; -// -// private AutoCloseable closeable; -// private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("foo").port(1234).build(); -// private final List hosts = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// storageService = new TestStorageServiceImpl(mockEventPublisher); -// when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); -// when(mockServicesContainer.getStorageService()).thenReturn(storageService); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); -// when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); -// when(mockConnection.createStatement()).thenReturn(mockStatement); -// when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); -// when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); -// when(mockHostListProviderService.getHostSpecBuilder()) -// .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// RdsMultiAzDbClusterListProvider.clearAll(); -// storageService.clearAll(); -// closeable.close(); -// } -// -// private RdsMultiAzDbClusterListProvider getRdsMazDbClusterHostListProvider(String originalUrl) throws SQLException { -// RdsMultiAzDbClusterListProvider provider = new RdsMultiAzDbClusterListProvider( -// new Properties(), -// originalUrl, -// mockServicesContainer, -// "foo", -// "bar", -// "baz", -// "fang", -// "li"); -// provider.init(); -// // provider.clusterId = "cluster-id"; -// return provider; -// } -// -// @Test -// void testGetTopology_returnCachedTopology() throws SQLException { -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("protocol://url/")); -// final List expected = hosts; -// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); -// -// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); -// assertEquals(expected, result.hosts); -// assertEquals(2, result.hosts.size()); -// verify(rdsMazDbClusterHostListProvider, never()).queryForTopology(mockConnection); -// } -// -// @Test -// void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// rdsMazDbClusterHostListProvider.isInitialized = true; -// -// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(hosts)); -// -// final List newHosts = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); -// doReturn(newHosts).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); -// -// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); -// verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertEquals(1, result.hosts.size()); -// assertEquals(newHosts, result.hosts); -// } -// -// @Test -// void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// rdsMazDbClusterHostListProvider.clusterId = "cluster-id"; -// rdsMazDbClusterHostListProvider.isInitialized = true; -// -// final List expected = hosts; -// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); -// -// doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); -// -// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); -// verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertEquals(2, result.hosts.size()); -// assertEquals(expected, result.hosts); -// } -// -// @Test -// void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// rdsMazDbClusterHostListProvider.clear(); -// -// doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); -// -// final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); -// verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); -// assertNotNull(result.hosts); -// assertEquals( -// Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), -// result.hosts); -// } -// -// @Test -// void testQueryForTopology_queryResultsInException() throws SQLException { -// rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("protocol://url/"); -// when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); -// -// assertThrows( -// SQLException.class, -// () -> rdsMazDbClusterHostListProvider.queryForTopology(mockConnection)); -// } -// -// @Test -// void testGetCachedTopology_returnCachedTopology() throws SQLException { -// rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url"); -// -// final List expected = hosts; -// storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); -// -// final List result = rdsMazDbClusterHostListProvider.getStoredTopology(); -// assertEquals(expected, result); -// } -// -// @Test -// void testTopologyCache_NoSuggestedClusterId() throws SQLException { -// RdsMultiAzDbClusterListProvider.clearAll(); -// -// RdsMultiAzDbClusterListProvider provider1 = -// Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-a.domain.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); -// -// doReturn(topologyClusterA) -// .when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// RdsMultiAzDbClusterListProvider provider2 = -// Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-b.domain.com/")); -// provider2.init(); -// assertNull(provider2.getStoredTopology()); -// -// final List topologyClusterB = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); -// doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); -// -// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterB, topologyProvider2); -// -// assertEquals(2, storageService.size(Topology.class)); -// } -// -// @Test -// void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { -// RdsMultiAzDbClusterListProvider.clearAll(); -// -// RdsMultiAzDbClusterListProvider provider1 = -// Mockito.spy(getRdsMazDbClusterHostListProvider( -// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build()); -// -// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// RdsMultiAzDbClusterListProvider provider2 = -// Mockito.spy(getRdsMazDbClusterHostListProvider( -// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider2.init(); -// -// assertEquals(provider1.clusterId, provider2.clusterId); -// assertTrue(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// -// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider2); -// -// assertEquals(1, storageService.size(Topology.class)); -// } -// -// @Test -// void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { -// RdsMultiAzDbClusterListProvider.clearAll(); -// -// RdsMultiAzDbClusterListProvider provider1 = -// Mockito.spy(getRdsMazDbClusterHostListProvider( -// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build()); -// -// doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// RdsMultiAzDbClusterListProvider provider2 = -// Mockito.spy(getRdsMazDbClusterHostListProvider( -// "jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); -// provider2.init(); -// -// assertEquals(provider1.clusterId, provider2.clusterId); -// assertTrue(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// -// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider2); -// -// assertEquals(1, storageService.size(Topology.class)); -// } -// -// @Test -// void testTopologyCache_AcceptSuggestion() throws SQLException { -// RdsMultiAzDbClusterListProvider.clearAll(); -// -// RdsMultiAzDbClusterListProvider provider1 = -// Mockito.spy(getRdsMazDbClusterHostListProvider( -// "jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); -// provider1.init(); -// final List topologyClusterA = Arrays.asList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build(), -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.READER) -// .build()); -// -// doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); -// -// assertEquals(0, storageService.size(Topology.class)); -// -// List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// -// // RdsMultiAzDbClusterListProvider.logCache(); -// -// RdsMultiAzDbClusterListProvider provider2 = -// Mockito.spy(getRdsMazDbClusterHostListProvider( -// "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); -// provider2.init(); -// -// doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); -// -// final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider2); -// -// assertNotEquals(provider1.clusterId, provider2.clusterId); -// assertFalse(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// assertEquals(2, storageService.size(Topology.class)); -// assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", -// RdsMultiAzDbClusterListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); -// -// // RdsMultiAzDbClusterListProvider.logCache(); -// -// topologyProvider1 = provider1.forceRefresh(Mockito.mock(Connection.class)); -// assertEquals(topologyClusterA, topologyProvider1); -// assertEquals(provider1.clusterId, provider2.clusterId); -// assertTrue(provider1.isPrimaryClusterId); -// assertTrue(provider2.isPrimaryClusterId); -// -// // RdsMultiAzDbClusterListProvider.logCache(); -// } -// -// @Test -// void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// -// when(mockResultSet.next()).thenReturn(false); -// assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); -// -// when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); -// assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); -// } -// -// @Test -// void testIdentifyConnectionNullTopology() throws SQLException { -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// rdsMazDbClusterHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("?.pattern").build(); -// -// when(mockResultSet.next()).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); -// doReturn(null).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); -// doReturn(null).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); -// -// assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); -// } -// -// @Test -// void testIdentifyConnectionHostNotInTopology() throws SQLException { -// final List cachedTopology = Collections.singletonList( -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build()); -// -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// when(mockResultSet.next()).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); -// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); -// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); -// -// assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); -// } -// -// @Test -// void testIdentifyConnectionHostInTopology() throws SQLException { -// final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") -// .hostId("instance-a-1") -// .port(HostSpec.NO_PORT) -// .role(HostRole.WRITER) -// .build(); -// final List cachedTopology = Collections.singletonList(expectedHost); -// -// rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); -// when(mockResultSet.next()).thenReturn(true); -// when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); -// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); -// doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); -// -// final HostSpec actual = rdsMazDbClusterHostListProvider.identifyConnection(mockConnection); -// assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); -// assertEquals("instance-a-1", actual.getHostId()); -// } -// -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.hostlistprovider; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atMostOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLSyntaxErrorException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.events.EventPublisher; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.storage.TestStorageServiceImpl; + +class RdsMultiAzDbClusterListProviderTest { + private StorageService storageService; + private RdsMultiAzDbClusterListProvider rdsMazDbClusterHostListProvider; + + @Mock private Connection mockConnection; + @Mock private Statement mockStatement; + @Mock private ResultSet mockResultSet; + @Mock private FullServicesContainer mockServicesContainer; + @Mock private PluginService mockPluginService; + @Mock private HostListProviderService mockHostListProviderService; + @Mock private EventPublisher mockEventPublisher; + @Mock Dialect mockTopologyAwareDialect; + @Captor private ArgumentCaptor queryCaptor; + + private AutoCloseable closeable; + private final HostSpec currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("foo").port(1234).build(); + private final List hosts = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host1").build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2").build()); + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + storageService = new TestStorageServiceImpl(mockEventPublisher); + when(mockServicesContainer.getHostListProviderService()).thenReturn(mockHostListProviderService); + when(mockServicesContainer.getStorageService()).thenReturn(storageService); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.connect(any(HostSpec.class), any(Properties.class))).thenReturn(mockConnection); + when(mockPluginService.getCurrentHostSpec()).thenReturn(currentHostSpec); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); + when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); + when(mockHostListProviderService.getHostSpecBuilder()) + .thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); + } + + @AfterEach + void tearDown() throws Exception { + RdsMultiAzDbClusterListProvider.clearAll(); + storageService.clearAll(); + closeable.close(); + } + + private RdsMultiAzDbClusterListProvider getRdsMazDbClusterHostListProvider(String originalUrl) throws SQLException { + RdsMultiAzDbClusterListProvider provider = new RdsMultiAzDbClusterListProvider( + new Properties(), + originalUrl, + mockServicesContainer, + "foo", + "bar", + "baz", + "fang", + "li"); + provider.init(); + // provider.clusterId = "cluster-id"; + return provider; + } + + @Test + void testGetTopology_returnCachedTopology() throws SQLException { + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("protocol://url/")); + final List expected = hosts; + storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); + + final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); + assertEquals(expected, result.hosts); + assertEquals(2, result.hosts.size()); + verify(rdsMazDbClusterHostListProvider, never()).queryForTopology(mockConnection); + } + + @Test + void testGetTopology_withForceUpdate_returnsUpdatedTopology() throws SQLException { + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + rdsMazDbClusterHostListProvider.isInitialized = true; + + storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(hosts)); + + final List newHosts = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("newHost").build()); + doReturn(newHosts).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); + + final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); + verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertEquals(1, result.hosts.size()); + assertEquals(newHosts, result.hosts); + } + + @Test + void testGetTopology_noForceUpdate_queryReturnsEmptyHostList() throws SQLException { + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + rdsMazDbClusterHostListProvider.clusterId = "cluster-id"; + rdsMazDbClusterHostListProvider.isInitialized = true; + + final List expected = hosts; + storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); + + doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); + + final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, false); + verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertEquals(2, result.hosts.size()); + assertEquals(expected, result.hosts); + } + + @Test + void testGetTopology_withForceUpdate_returnsInitialHostList() throws SQLException { + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + rdsMazDbClusterHostListProvider.clear(); + + doReturn(new ArrayList<>()).when(rdsMazDbClusterHostListProvider).queryForTopology(mockConnection); + + final FetchTopologyResult result = rdsMazDbClusterHostListProvider.getTopology(mockConnection, true); + verify(rdsMazDbClusterHostListProvider, atMostOnce()).queryForTopology(mockConnection); + assertNotNull(result.hosts); + assertEquals( + Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("url").build()), + result.hosts); + } + + @Test + void testQueryForTopology_queryResultsInException() throws SQLException { + rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("protocol://url/"); + when(mockStatement.executeQuery(queryCaptor.capture())).thenThrow(new SQLSyntaxErrorException()); + + assertThrows( + SQLException.class, + () -> rdsMazDbClusterHostListProvider.queryForTopology(mockConnection)); + } + + @Test + void testGetCachedTopology_returnCachedTopology() throws SQLException { + rdsMazDbClusterHostListProvider = getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url"); + + final List expected = hosts; + storageService.set(rdsMazDbClusterHostListProvider.clusterId, new Topology(expected)); + + final List result = rdsMazDbClusterHostListProvider.getStoredTopology(); + assertEquals(expected, result); + } + + @Test + void testTopologyCache_NoSuggestedClusterId() throws SQLException { + RdsMultiAzDbClusterListProvider.clearAll(); + + RdsMultiAzDbClusterListProvider provider1 = + Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-a.domain.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); + + doReturn(topologyClusterA) + .when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + RdsMultiAzDbClusterListProvider provider2 = + Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:something://cluster-b.domain.com/")); + provider2.init(); + assertNull(provider2.getStoredTopology()); + + final List topologyClusterB = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-b-1.domain.com").port(HostSpec.NO_PORT).role(HostRole.WRITER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-b-2.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-b-3.domain.com").port(HostSpec.NO_PORT).role(HostRole.READER).build()); + doReturn(topologyClusterB).when(provider2).queryForTopology(any(Connection.class)); + + final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterB, topologyProvider2); + + assertEquals(2, storageService.size(Topology.class)); + } + + @Test + void testTopologyCache_SuggestedClusterIdForRds() throws SQLException { + RdsMultiAzDbClusterListProvider.clearAll(); + + RdsMultiAzDbClusterListProvider provider1 = + Mockito.spy(getRdsMazDbClusterHostListProvider( + "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build()); + + doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + RdsMultiAzDbClusterListProvider provider2 = + Mockito.spy(getRdsMazDbClusterHostListProvider( + "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider2.init(); + + assertEquals(provider1.clusterId, provider2.clusterId); + assertTrue(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + + final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider2); + + assertEquals(1, storageService.size(Topology.class)); + } + + @Test + void testTopologyCache_SuggestedClusterIdForInstance() throws SQLException { + RdsMultiAzDbClusterListProvider.clearAll(); + + RdsMultiAzDbClusterListProvider provider1 = + Mockito.spy(getRdsMazDbClusterHostListProvider( + "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build()); + + doReturn(topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + final List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + RdsMultiAzDbClusterListProvider provider2 = + Mockito.spy(getRdsMazDbClusterHostListProvider( + "jdbc:something://instance-a-3.xyz.us-east-2.rds.amazonaws.com/")); + provider2.init(); + + assertEquals(provider1.clusterId, provider2.clusterId); + assertTrue(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + + final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider2); + + assertEquals(1, storageService.size(Topology.class)); + } + + @Test + void testTopologyCache_AcceptSuggestion() throws SQLException { + RdsMultiAzDbClusterListProvider.clearAll(); + + RdsMultiAzDbClusterListProvider provider1 = + Mockito.spy(getRdsMazDbClusterHostListProvider( + "jdbc:something://instance-a-2.xyz.us-east-2.rds.amazonaws.com/")); + provider1.init(); + final List topologyClusterA = Arrays.asList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-2.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build(), + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-3.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.READER) + .build()); + + doAnswer(a -> topologyClusterA).when(provider1).queryForTopology(any(Connection.class)); + + assertEquals(0, storageService.size(Topology.class)); + + List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + + // RdsMultiAzDbClusterListProvider.logCache(); + + RdsMultiAzDbClusterListProvider provider2 = + Mockito.spy(getRdsMazDbClusterHostListProvider( + "jdbc:something://cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/")); + provider2.init(); + + doAnswer(a -> topologyClusterA).when(provider2).queryForTopology(any(Connection.class)); + + final List topologyProvider2 = provider2.refresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider2); + + assertNotEquals(provider1.clusterId, provider2.clusterId); + assertFalse(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + assertEquals(2, storageService.size(Topology.class)); + assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com", + RdsMultiAzDbClusterListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); + + // RdsMultiAzDbClusterListProvider.logCache(); + + topologyProvider1 = provider1.forceRefresh(Mockito.mock(Connection.class)); + assertEquals(topologyClusterA, topologyProvider1); + assertEquals(provider1.clusterId, provider2.clusterId); + assertTrue(provider1.isPrimaryClusterId); + assertTrue(provider2.isPrimaryClusterId); + + // RdsMultiAzDbClusterListProvider.logCache(); + } + + @Test + void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + + when(mockResultSet.next()).thenReturn(false); + assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); + + when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); + assertThrows(SQLException.class, () -> rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionNullTopology() throws SQLException { + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + rdsMazDbClusterHostListProvider.clusterInstanceTemplate = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("?.pattern").build(); + + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); + doReturn(null).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); + doReturn(null).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); + + assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionHostNotInTopology() throws SQLException { + final List cachedTopology = Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build()); + + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); + doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); + doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); + + assertNull(rdsMazDbClusterHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionHostInTopology() throws SQLException { + final HostSpec expectedHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-a-1.xyz.us-east-2.rds.amazonaws.com") + .hostId("instance-a-1") + .port(HostSpec.NO_PORT) + .role(HostRole.WRITER) + .build(); + final List cachedTopology = Collections.singletonList(expectedHost); + + rdsMazDbClusterHostListProvider = Mockito.spy(getRdsMazDbClusterHostListProvider("jdbc:someprotocol://url")); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); + doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).refresh(mockConnection); + doReturn(cachedTopology).when(rdsMazDbClusterHostListProvider).forceRefresh(mockConnection); + + final HostSpec actual = rdsMazDbClusterHostListProvider.identifyConnection(mockConnection); + assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); + assertEquals("instance-a-1", actual.getHostId()); + } + +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java index 6c0d439c0..d92930cd9 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java @@ -1,155 +1,155 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.mock; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.EnumSet; -// import java.util.HashSet; -// import java.util.List; -// import java.util.Map; -// import java.util.Properties; -// import java.util.Set; -// import software.amazon.jdbc.ConnectionPlugin; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.NodeChangeOptions; -// import software.amazon.jdbc.OldConnectionSuggestedAction; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.util.connection.ConnectionContext; -// -// public class TestPluginOne implements ConnectionPlugin { -// -// protected Set subscribedMethods; -// protected ArrayList calls; -// -// TestPluginOne() {} -// -// public TestPluginOne(ArrayList calls) { -// this.calls = calls; -// -// this.subscribedMethods = new HashSet<>(Arrays.asList("*")); -// } -// -// @Override -// public Set getSubscribedMethods() { -// return this.subscribedMethods; -// } -// -// @Override -// public T execute( -// Class resultClass, -// Class exceptionClass, -// Object methodInvokeOn, -// String methodName, -// JdbcCallable jdbcMethodFunc, -// Object[] jdbcMethodArgs) -// throws E { -// -// this.calls.add(this.getClass().getSimpleName() + ":before"); -// -// T result; -// try { -// result = jdbcMethodFunc.call(); -// } catch (RuntimeException e) { -// throw e; -// } catch (Exception e) { -// if (exceptionClass.isInstance(e)) { -// throw exceptionClass.cast(e); -// } -// throw new RuntimeException(e); -// } -// -// this.calls.add(this.getClass().getSimpleName() + ":after"); -// -// return result; -// } -// -// @Override -// public Connection connect( -// final ConnectionContext connectionContext, -// final HostSpec hostSpec, -// final boolean isInitialConnection, -// final JdbcCallable connectFunc) throws SQLException { -// -// this.calls.add(this.getClass().getSimpleName() + ":before connect"); -// Connection result = connectFunc.call(); -// this.calls.add(this.getClass().getSimpleName() + ":after connect"); -// return result; -// } -// -// @Override -// public Connection forceConnect( -// String driverProtocol, -// HostSpec hostSpec, -// Properties props, -// boolean isInitialConnection, -// JdbcCallable forceConnectFunc) -// throws SQLException { -// -// this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); -// Connection result = forceConnectFunc.call(); -// this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); -// return result; -// } -// -// @Override -// public boolean acceptsStrategy(HostRole role, String strategy) { -// return false; -// } -// -// @Override -// public HostSpec getHostSpecByStrategy(HostRole role, String strategy) { -// this.calls.add(this.getClass().getSimpleName() + ":before getHostSpecByStrategy"); -// HostSpec result = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("host").port(1234).role(role).build(); -// this.calls.add(this.getClass().getSimpleName() + ":after getHostSpecByStrategy"); -// return result; -// } -// -// @Override -// public HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strategy) { -// return getHostSpecByStrategy(role, strategy); -// } -// -// @Override -// public void initHostProvider( -// String driverProtocol, -// String initialUrl, -// Properties props, -// HostListProviderService hostListProviderService, -// JdbcCallable initHostProviderFunc) -// throws SQLException { -// -// // do nothing -// } -// -// @Override -// public OldConnectionSuggestedAction notifyConnectionChanged(EnumSet changes) { -// return OldConnectionSuggestedAction.NO_OPINION; -// } -// -// @Override -// public void notifyNodeListChanged(Map> changes) { -// // do nothing -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.mock; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.NodeChangeOptions; +import software.amazon.jdbc.OldConnectionSuggestedAction; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.connection.ConnectionContext; + +public class TestPluginOne implements ConnectionPlugin { + + protected Set subscribedMethods; + protected ArrayList calls; + + TestPluginOne() {} + + public TestPluginOne(ArrayList calls) { + this.calls = calls; + + this.subscribedMethods = new HashSet<>(Arrays.asList("*")); + } + + @Override + public Set getSubscribedMethods() { + return this.subscribedMethods; + } + + @Override + public T execute( + Class resultClass, + Class exceptionClass, + Object methodInvokeOn, + String methodName, + JdbcCallable jdbcMethodFunc, + Object[] jdbcMethodArgs) + throws E { + + this.calls.add(this.getClass().getSimpleName() + ":before"); + + T result; + try { + result = jdbcMethodFunc.call(); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + if (exceptionClass.isInstance(e)) { + throw exceptionClass.cast(e); + } + throw new RuntimeException(e); + } + + this.calls.add(this.getClass().getSimpleName() + ":after"); + + return result; + } + + @Override + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { + + this.calls.add(this.getClass().getSimpleName() + ":before connect"); + Connection result = connectFunc.call(); + this.calls.add(this.getClass().getSimpleName() + ":after connect"); + return result; + } + + @Override + public Connection forceConnect( + String driverProtocol, + HostSpec hostSpec, + Properties props, + boolean isInitialConnection, + JdbcCallable forceConnectFunc) + throws SQLException { + + this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); + Connection result = forceConnectFunc.call(); + this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); + return result; + } + + @Override + public boolean acceptsStrategy(HostRole role, String strategy) { + return false; + } + + @Override + public HostSpec getHostSpecByStrategy(HostRole role, String strategy) { + this.calls.add(this.getClass().getSimpleName() + ":before getHostSpecByStrategy"); + HostSpec result = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("host").port(1234).role(role).build(); + this.calls.add(this.getClass().getSimpleName() + ":after getHostSpecByStrategy"); + return result; + } + + @Override + public HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strategy) { + return getHostSpecByStrategy(role, strategy); + } + + @Override + public void initHostProvider( + String driverProtocol, + String initialUrl, + Properties props, + HostListProviderService hostListProviderService, + JdbcCallable initHostProviderFunc) + throws SQLException { + + // do nothing + } + + @Override + public OldConnectionSuggestedAction notifyConnectionChanged(EnumSet changes) { + return OldConnectionSuggestedAction.NO_OPINION; + } + + @Override + public void notifyNodeListChanged(Map> changes) { + // do nothing + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java index f62a93499..876444bc8 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java @@ -1,87 +1,87 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.mock; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.HashSet; -// import java.util.Properties; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.JdbcMethod; -// import software.amazon.jdbc.util.connection.ConnectionContext; -// -// public class TestPluginThree extends TestPluginOne { -// -// private Connection connection; -// -// public TestPluginThree(ArrayList calls) { -// super(); -// this.calls = calls; -// -// this.subscribedMethods = new HashSet<>(Arrays.asList( -// JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.CONNECT.methodName, JdbcMethod.FORCECONNECT.methodName)); -// } -// -// public TestPluginThree(ArrayList calls, Connection connection) { -// this(calls); -// this.connection = connection; -// } -// -// @Override -// public Connection connect( -// final ConnectionContext connectionContext, -// final HostSpec hostSpec, -// final boolean isInitialConnection, -// final JdbcCallable connectFunc) throws SQLException { -// -// this.calls.add(this.getClass().getSimpleName() + ":before connect"); -// -// if (this.connection != null) { -// this.calls.add(this.getClass().getSimpleName() + ":connection"); -// return this.connection; -// } -// -// Connection result = connectFunc.call(); -// this.calls.add(this.getClass().getSimpleName() + ":after connect"); -// -// return result; -// } -// -// public Connection forceConnect( -// String driverProtocol, -// HostSpec hostSpec, -// Properties props, -// boolean isInitialConnection, -// JdbcCallable forceConnectFunc) -// throws SQLException { -// -// this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); -// -// if (this.connection != null) { -// this.calls.add(this.getClass().getSimpleName() + ":forced connection"); -// return this.connection; -// } -// -// Connection result = forceConnectFunc.call(); -// this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); -// -// return result; -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.mock; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Properties; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.util.connection.ConnectionContext; + +public class TestPluginThree extends TestPluginOne { + + private Connection connection; + + public TestPluginThree(ArrayList calls) { + super(); + this.calls = calls; + + this.subscribedMethods = new HashSet<>(Arrays.asList( + JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.CONNECT.methodName, JdbcMethod.FORCECONNECT.methodName)); + } + + public TestPluginThree(ArrayList calls, Connection connection) { + this(calls); + this.connection = connection; + } + + @Override + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { + + this.calls.add(this.getClass().getSimpleName() + ":before connect"); + + if (this.connection != null) { + this.calls.add(this.getClass().getSimpleName() + ":connection"); + return this.connection; + } + + Connection result = connectFunc.call(); + this.calls.add(this.getClass().getSimpleName() + ":after connect"); + + return result; + } + + public Connection forceConnect( + String driverProtocol, + HostSpec hostSpec, + Properties props, + boolean isInitialConnection, + JdbcCallable forceConnectFunc) + throws SQLException { + + this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); + + if (this.connection != null) { + this.calls.add(this.getClass().getSimpleName() + ":forced connection"); + return this.connection; + } + + Connection result = forceConnectFunc.call(); + this.calls.add(this.getClass().getSimpleName() + ":after forceConnect"); + + return result; + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java index 84bdfd414..91f7e6260 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java @@ -1,112 +1,112 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.mock; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.HashSet; -// import java.util.Properties; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.util.connection.ConnectionContext; -// -// public class TestPluginThrowException extends TestPluginOne { -// -// protected final Class exceptionClass; -// protected final boolean isBefore; -// -// public TestPluginThrowException( -// ArrayList calls, Class exceptionClass, boolean isBefore) { -// super(); -// this.calls = calls; -// this.exceptionClass = exceptionClass; -// this.isBefore = isBefore; -// -// this.subscribedMethods = new HashSet<>(Arrays.asList("*")); -// } -// -// @Override -// public T execute( -// Class resultClass, -// Class exceptionClass, -// Object methodInvokeOn, -// String methodName, -// JdbcCallable jdbcMethodFunc, -// Object[] jdbcMethodArgs) -// throws E { -// -// this.calls.add(this.getClass().getSimpleName() + ":before"); -// if (this.isBefore) { -// try { -// throw this.exceptionClass.newInstance(); -// } catch (Exception e) { -// throw new RuntimeException(e); -// } -// } -// -// T result = jdbcMethodFunc.call(); -// -// this.calls.add(this.getClass().getSimpleName() + ":after"); -// //noinspection ConstantConditions -// if (!this.isBefore) { -// try { -// throw this.exceptionClass.newInstance(); -// } catch (Exception e) { -// throw new RuntimeException(e); -// } -// } -// -// return result; -// } -// -// @Override -// public Connection connect( -// final ConnectionContext connectionContext, -// final HostSpec hostSpec, -// final boolean isInitialConnection, -// final JdbcCallable connectFunc) throws SQLException { -// -// this.calls.add(this.getClass().getSimpleName() + ":before"); -// if (this.isBefore) { -// throwException(); -// } -// -// Connection conn = connectFunc.call(); -// -// this.calls.add(this.getClass().getSimpleName() + ":after"); -// if (!this.isBefore) { -// throwException(); -// } -// -// return conn; -// } -// -// private void throwException() throws SQLException { -// try { -// throw this.exceptionClass.newInstance(); -// } catch (RuntimeException e) { -// throw e; -// } catch (Exception e) { -// if (e instanceof SQLException) { -// throw (SQLException) e; -// } -// throw new SQLException(e); -// } -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.mock; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Properties; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.util.connection.ConnectionContext; + +public class TestPluginThrowException extends TestPluginOne { + + protected final Class exceptionClass; + protected final boolean isBefore; + + public TestPluginThrowException( + ArrayList calls, Class exceptionClass, boolean isBefore) { + super(); + this.calls = calls; + this.exceptionClass = exceptionClass; + this.isBefore = isBefore; + + this.subscribedMethods = new HashSet<>(Arrays.asList("*")); + } + + @Override + public T execute( + Class resultClass, + Class exceptionClass, + Object methodInvokeOn, + String methodName, + JdbcCallable jdbcMethodFunc, + Object[] jdbcMethodArgs) + throws E { + + this.calls.add(this.getClass().getSimpleName() + ":before"); + if (this.isBefore) { + try { + throw this.exceptionClass.newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + T result = jdbcMethodFunc.call(); + + this.calls.add(this.getClass().getSimpleName() + ":after"); + //noinspection ConstantConditions + if (!this.isBefore) { + try { + throw this.exceptionClass.newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + return result; + } + + @Override + public Connection connect( + final ConnectionContext connectionContext, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable connectFunc) throws SQLException { + + this.calls.add(this.getClass().getSimpleName() + ":before"); + if (this.isBefore) { + throwException(); + } + + Connection conn = connectFunc.call(); + + this.calls.add(this.getClass().getSimpleName() + ":after"); + if (!this.isBefore) { + throwException(); + } + + return conn; + } + + private void throwException() throws SQLException { + try { + throw this.exceptionClass.newInstance(); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + if (e instanceof SQLException) { + throw (SQLException) e; + } + throw new SQLException(e); + } + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java index 0a96cd136..82af4b437 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginTwo.java @@ -1,33 +1,33 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.mock; -// -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.HashSet; -// import software.amazon.jdbc.JdbcMethod; -// -// public class TestPluginTwo extends TestPluginOne { -// -// public TestPluginTwo(ArrayList calls) { -// super(); -// this.calls = calls; -// -// this.subscribedMethods = new HashSet<>( -// Arrays.asList(JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.BLOB_POSITION.methodName)); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.mock; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import software.amazon.jdbc.JdbcMethod; + +public class TestPluginTwo extends TestPluginOne { + + public TestPluginTwo(ArrayList calls) { + super(); + this.calls = calls; + + this.subscribedMethods = new HashSet<>( + Arrays.asList(JdbcMethod.BLOB_LENGTH.methodName, JdbcMethod.BLOB_POSITION.methodName)); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index 8d00543a1..70d62ae9f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -1,245 +1,245 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doThrow; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.Collections; -// import java.util.HashSet; -// import java.util.Properties; -// import java.util.Set; -// import java.util.stream.Stream; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.Arguments; -// import org.junit.jupiter.params.provider.MethodSource; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.JdbcMethod; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.plugin.failover.FailoverSQLException; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.RdsUrlType; -// import software.amazon.jdbc.util.RdsUtils; -// -// public class AuroraConnectionTrackerPluginTest { -// -// public static final Properties EMPTY_PROPERTIES = new Properties(); -// @Mock Connection mockConnection; -// @Mock Statement mockStatement; -// @Mock ResultSet mockResultSet; -// @Mock PluginService mockPluginService; -// @Mock Dialect mockTopologyAwareDialect; -// @Mock RdsUtils mockRdsUtils; -// @Mock OpenedConnectionTracker mockTracker; -// @Mock JdbcCallable mockConnectionFunction; -// @Mock JdbcCallable mockSqlFunction; -// @Mock JdbcCallable mockCloseOrAbortFunction; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// -// private static final Object[] SQL_ARGS = {"sql"}; -// -// private AutoCloseable closeable; -// -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// when(mockConnectionFunction.call()).thenReturn(mockConnection); -// when(mockSqlFunction.call()).thenReturn(mockResultSet); -// when(mockConnection.createStatement()).thenReturn(mockStatement); -// when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); -// when(mockRdsUtils.getRdsInstanceHostPattern(any(String.class))).thenReturn("?"); -// when(mockRdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); -// when(mockPluginService.getDialect()).thenReturn(mockTopologyAwareDialect); -// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); -// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// } -// -// @ParameterizedTest -// @MethodSource("trackNewConnectionsParameters") -// public void testTrackNewInstanceConnections( -// final String protocol, -// final boolean isInitialConnection) throws SQLException { -// final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance1") -// .build(); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); -// when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true); -// -// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( -// mockPluginService, -// EMPTY_PROPERTIES, -// mockRdsUtils, -// mockTracker); -// -// final Connection actualConnection = plugin.connect( -// protocol, -// hostSpec, -// EMPTY_PROPERTIES, -// isInitialConnection, -// mockConnectionFunction); -// -// assertEquals(mockConnection, actualConnection); -// verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); -// final Set aliases = hostSpec.getAliases(); -// assertEquals(0, aliases.size()); -// } -// -// @Test -// public void testInvalidateOpenedConnectionsWhenWriterHostNotChange() throws SQLException { -// final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); -// final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("host") -// .role(HostRole.WRITER) -// .build(); -// final HostSpec newHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("new-host") -// .role(HostRole.WRITER) -// .build(); -// -// // Host list changes during simulated failover -// when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList(originalHost)); -// doThrow(expectedException).when(mockSqlFunction).call(); -// -// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( -// mockPluginService, -// EMPTY_PROPERTIES, -// mockRdsUtils, -// mockTracker); -// -// final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( -// ResultSet.class, -// SQLException.class, -// Statement.class, -// "Statement.executeQuery", -// mockSqlFunction, -// SQL_ARGS -// )); -// -// assertEquals(expectedException, exception); -// verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); -// verify(mockTracker, never()).invalidateAllConnections(originalHost); -// } -// -// @Test -// public void testInvalidateOpenedConnectionsWhenWriterHostChanged() throws SQLException { -// final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); -// final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") -// .build(); -// final HostSpec failoverTargetHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2") -// .build(); -// when(mockPluginService.getAllHosts()) -// .thenReturn(Collections.singletonList(originalHost)) -// .thenReturn(Collections.singletonList(failoverTargetHost)); -// when(mockSqlFunction.call()) -// .thenReturn(mockResultSet) -// .thenThrow(expectedException); -// -// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( -// mockPluginService, -// EMPTY_PROPERTIES, -// mockRdsUtils, -// mockTracker); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// Statement.class, -// "Statement.executeQuery", -// mockSqlFunction, -// SQL_ARGS -// ); -// -// final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( -// ResultSet.class, -// SQLException.class, -// Statement.class, -// "Statement.executeQuery", -// mockSqlFunction, -// SQL_ARGS -// )); -// assertEquals(expectedException, exception); -// verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); -// verify(mockTracker).invalidateAllConnections(originalHost); -// } -// -// @ParameterizedTest -// @MethodSource("testInvalidateConnectionsOnCloseOrAbortArgs") -// public void testInvalidateConnectionsOnCloseOrAbort(final String method) throws SQLException { -// final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") -// .build(); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(originalHost); -// -// final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( -// mockPluginService, -// EMPTY_PROPERTIES, -// mockRdsUtils, -// mockTracker); -// -// plugin.execute( -// Void.class, -// SQLException.class, -// Connection.class, -// method, -// mockCloseOrAbortFunction, -// SQL_ARGS -// ); -// -// verify(mockTracker).removeConnectionTracking(eq(originalHost), eq(mockConnection)); -// } -// -// static Stream testInvalidateConnectionsOnCloseOrAbortArgs() { -// return Stream.of( -// Arguments.of(JdbcMethod.CONNECTION_ABORT.methodName), -// Arguments.of(JdbcMethod.CONNECTION_CLOSE.methodName) -// ); -// } -// -// private static Stream trackNewConnectionsParameters() { -// return Stream.of( -// Arguments.of("postgresql", true), -// Arguments.of("postgresql", false), -// Arguments.of("otherProtocol", true), -// Arguments.of("otherProtocol", false) -// ); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.failover.FailoverSQLException; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.RdsUrlType; +import software.amazon.jdbc.util.RdsUtils; + +public class AuroraConnectionTrackerPluginTest { + + public static final Properties EMPTY_PROPERTIES = new Properties(); + @Mock Connection mockConnection; + @Mock Statement mockStatement; + @Mock ResultSet mockResultSet; + @Mock PluginService mockPluginService; + @Mock Dialect mockTopologyAwareDialect; + @Mock RdsUtils mockRdsUtils; + @Mock OpenedConnectionTracker mockTracker; + @Mock JdbcCallable mockConnectionFunction; + @Mock JdbcCallable mockSqlFunction; + @Mock JdbcCallable mockCloseOrAbortFunction; + @Mock TargetDriverDialect mockTargetDriverDialect; + + private static final Object[] SQL_ARGS = {"sql"}; + + private AutoCloseable closeable; + + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + when(mockConnectionFunction.call()).thenReturn(mockConnection); + when(mockSqlFunction.call()).thenReturn(mockResultSet); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); + when(mockRdsUtils.getRdsInstanceHostPattern(any(String.class))).thenReturn("?"); + when(mockRdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.getDialect()).thenReturn(mockTopologyAwareDialect); + when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); + when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @ParameterizedTest + @MethodSource("trackNewConnectionsParameters") + public void testTrackNewInstanceConnections( + final String protocol, + final boolean isInitialConnection) throws SQLException { + final HostSpec hostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("instance1") + .build(); + when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); + when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true); + + final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( + mockPluginService, + EMPTY_PROPERTIES, + mockRdsUtils, + mockTracker); + + final Connection actualConnection = plugin.connect( + protocol, + hostSpec, + EMPTY_PROPERTIES, + isInitialConnection, + mockConnectionFunction); + + assertEquals(mockConnection, actualConnection); + verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); + final Set aliases = hostSpec.getAliases(); + assertEquals(0, aliases.size()); + } + + @Test + public void testInvalidateOpenedConnectionsWhenWriterHostNotChange() throws SQLException { + final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); + final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("host") + .role(HostRole.WRITER) + .build(); + final HostSpec newHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("new-host") + .role(HostRole.WRITER) + .build(); + + // Host list changes during simulated failover + when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList(originalHost)); + doThrow(expectedException).when(mockSqlFunction).call(); + + final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( + mockPluginService, + EMPTY_PROPERTIES, + mockRdsUtils, + mockTracker); + + final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( + ResultSet.class, + SQLException.class, + Statement.class, + "Statement.executeQuery", + mockSqlFunction, + SQL_ARGS + )); + + assertEquals(expectedException, exception); + verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); + verify(mockTracker, never()).invalidateAllConnections(originalHost); + } + + @Test + public void testInvalidateOpenedConnectionsWhenWriterHostChanged() throws SQLException { + final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); + final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") + .build(); + final HostSpec failoverTargetHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2") + .build(); + when(mockPluginService.getAllHosts()) + .thenReturn(Collections.singletonList(originalHost)) + .thenReturn(Collections.singletonList(failoverTargetHost)); + when(mockSqlFunction.call()) + .thenReturn(mockResultSet) + .thenThrow(expectedException); + + final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( + mockPluginService, + EMPTY_PROPERTIES, + mockRdsUtils, + mockTracker); + + plugin.execute( + ResultSet.class, + SQLException.class, + Statement.class, + "Statement.executeQuery", + mockSqlFunction, + SQL_ARGS + ); + + final SQLException exception = assertThrows(FailoverSQLException.class, () -> plugin.execute( + ResultSet.class, + SQLException.class, + Statement.class, + "Statement.executeQuery", + mockSqlFunction, + SQL_ARGS + )); + assertEquals(expectedException, exception); + verify(mockTracker, never()).removeConnectionTracking(eq(originalHost), eq(mockConnection)); + verify(mockTracker).invalidateAllConnections(originalHost); + } + + @ParameterizedTest + @MethodSource("testInvalidateConnectionsOnCloseOrAbortArgs") + public void testInvalidateConnectionsOnCloseOrAbort(final String method) throws SQLException { + final HostSpec originalHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host") + .build(); + when(mockPluginService.getCurrentHostSpec()).thenReturn(originalHost); + + final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( + mockPluginService, + EMPTY_PROPERTIES, + mockRdsUtils, + mockTracker); + + plugin.execute( + Void.class, + SQLException.class, + Connection.class, + method, + mockCloseOrAbortFunction, + SQL_ARGS + ); + + verify(mockTracker).removeConnectionTracking(eq(originalHost), eq(mockConnection)); + } + + static Stream testInvalidateConnectionsOnCloseOrAbortArgs() { + return Stream.of( + Arguments.of(JdbcMethod.CONNECTION_ABORT.methodName), + Arguments.of(JdbcMethod.CONNECTION_CLOSE.methodName) + ); + } + + private static Stream trackNewConnectionsParameters() { + return Stream.of( + Arguments.of("postgresql", true), + Arguments.of("postgresql", false), + Arguments.of("otherProtocol", true), + Arguments.of("otherProtocol", false) + ); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java index e7517c933..22a8339c1 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java @@ -1,516 +1,516 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertNotEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doThrow; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.REGION_PROPERTY; -// import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.SECRET_ID_PROPERTY; -// -// import com.mysql.cj.exceptions.CJException; -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.Properties; -// import java.util.stream.Stream; -// import org.jetbrains.annotations.NotNull; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.Arguments; -// import org.junit.jupiter.params.provider.MethodSource; -// import org.junit.jupiter.params.provider.ValueSource; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import org.postgresql.util.PSQLException; -// import org.postgresql.util.PSQLState; -// import software.amazon.awssdk.regions.Region; -// import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; -// import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; -// import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; -// import software.amazon.awssdk.services.secretsmanager.model.SecretsManagerException; -// import software.amazon.jdbc.ConnectionPluginManager; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginServiceImpl; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.dialect.DialectManager; -// import software.amazon.jdbc.exceptions.ExceptionHandler; -// import software.amazon.jdbc.exceptions.ExceptionManager; -// import software.amazon.jdbc.exceptions.MySQLExceptionHandler; -// import software.amazon.jdbc.exceptions.PgExceptionHandler; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.profile.ConfigurationProfile; -// import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -// import software.amazon.jdbc.states.SessionStateService; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.Messages; -// import software.amazon.jdbc.util.Pair; -// import software.amazon.jdbc.util.telemetry.GaugeCallable; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryGauge; -// -// @SuppressWarnings("resource") -// public class AwsSecretsManagerConnectionPluginTest { -// -// private static final String TEST_PG_PROTOCOL = "jdbc:aws-wrapper:postgresql:"; -// private static final String TEST_MYSQL_PROTOCOL = "jdbc:aws-wrapper:mysql:"; -// private static final String TEST_REGION = "us-east-2"; -// private static final String TEST_SECRET_ID = "secretId"; -// private static final String TEST_USERNAME = "testUser"; -// private static final String TEST_PASSWORD = "testPassword"; -// private static final String VALID_SECRET_STRING = -// "{\"username\": \"" + TEST_USERNAME + "\", \"password\": \"" + TEST_PASSWORD + "\"}"; -// private static final String INVALID_SECRET_STRING = "{username: invalid, password: invalid}"; -// private static final String TEST_HOST = "test-domain"; -// private static final String TEST_SQL_ERROR = "SQL exception error message"; -// private static final String UNHANDLED_ERROR_CODE = "HY000"; -// private static final int TEST_PORT = 5432; -// private static final Pair SECRET_CACHE_KEY = Pair.create(TEST_SECRET_ID, TEST_REGION); -// private static final AwsSecretsManagerConnectionPlugin.Secret TEST_SECRET = -// new AwsSecretsManagerConnectionPlugin.Secret("testUser", "testPassword"); -// private static final HostSpec TEST_HOSTSPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host(TEST_HOST).port(TEST_PORT).build(); -// private static final GetSecretValueResponse VALID_GET_SECRET_VALUE_RESPONSE = -// GetSecretValueResponse.builder().secretString(VALID_SECRET_STRING).build(); -// private static final GetSecretValueResponse INVALID_GET_SECRET_VALUE_RESPONSE = -// GetSecretValueResponse.builder().secretString(INVALID_SECRET_STRING).build(); -// private static final Properties TEST_PROPS = new Properties(); -// private AwsSecretsManagerConnectionPlugin plugin; -// -// private AutoCloseable closeable; -// -// @Mock FullServicesContainer mockServicesContainer; -// @Mock SecretsManagerClient mockSecretsManagerClient; -// @Mock GetSecretValueRequest mockGetValueRequest; -// @Mock JdbcCallable connectFunc; -// @Mock PluginServiceImpl mockService; -// @Mock ConnectionPluginManager mockConnectionPluginManager; -// @Mock Dialect mockTopologyAwareDialect; -// @Mock DialectManager mockDialectManager; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryGauge mockTelemetryGauge; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); -// -// @Mock SessionStateService mockSessionStateService; -// -// @BeforeEach -// public void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// -// REGION_PROPERTY.set(TEST_PROPS, TEST_REGION); -// SECRET_ID_PROPERTY.set(TEST_PROPS, TEST_SECRET_ID); -// -// when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) -// .thenReturn(mockTopologyAwareDialect); -// -// when(mockServicesContainer.getConnectionPluginManager()).thenReturn(mockConnectionPluginManager); -// when(mockService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// // noinspection unchecked -// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); -// -// this.plugin = new AwsSecretsManagerConnectionPlugin( -// mockService, -// TEST_PROPS, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest); -// -// when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) -// .thenReturn(mockTopologyAwareDialect); -// -// when(mockService.getHostSpecBuilder()).thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// AwsSecretsManagerCacheHolder.clearCache(); -// TEST_PROPS.clear(); -// } -// -// /** -// * The plugin will successfully open a connection with a cached secret. -// */ -// @Test -// public void testConnectWithCachedSecrets() throws SQLException { -// // Add initial cached secret to be used for a connection. -// AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); -// -// this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); -// -// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); -// verify(this.connectFunc).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// /** -// * The plugin will attempt to open a connection with an empty secret cache. The plugin will fetch the secret from the -// * AWS Secrets Manager. -// */ -// @Test -// public void testConnectWithNewSecrets() throws SQLException { -// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) -// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); -// -// this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); -// -// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); -// verify(connectFunc).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// @ParameterizedTest -// @MethodSource("missingArguments") -// public void testMissingRequiredParameters(final Properties properties) { -// assertThrows(RuntimeException.class, () -> new AwsSecretsManagerConnectionPlugin( -// mockService, -// properties, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest)); -// } -// -// /** -// * The plugin will attempt to open a connection with a cached secret, but it will fail with a generic SQL exception. -// * In this case, the plugin will rethrow the error back to the user. -// */ -// @Test -// public void testFailedInitialConnectionWithUnhandledError() throws SQLException { -// AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); -// final SQLException failedFirstConnectionGenericException = new SQLException(TEST_SQL_ERROR, UNHANDLED_ERROR_CODE); -// doThrow(failedFirstConnectionGenericException).when(connectFunc).call(); -// -// final SQLException connectionFailedException = assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_PG_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals(TEST_SQL_ERROR, connectionFailedException.getMessage()); -// verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); -// verify(connectFunc).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// /** -// * The plugin will attempt to open a connection with a cached secret, but it will fail with an access error. In this -// * case, the plugin will fetch the secret and will retry the connection. -// */ -// @ParameterizedTest -// @MethodSource("provideExceptionCodeForDifferentDrivers") -// public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( -// String accessError, -// String protocol, -// ExceptionHandler exceptionHandler) throws SQLException { -// this.plugin = new AwsSecretsManagerConnectionPlugin( -// getPluginService(protocol), -// TEST_PROPS, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest); -// -// // Fail the initial connection attempt with cached secret. -// // Second attempt should be successful. -// AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); -// final SQLException failedFirstConnectionAccessException = new SQLException(TEST_SQL_ERROR, -// accessError); -// doThrow(failedFirstConnectionAccessException).when(connectFunc).call(); -// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) -// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); -// -// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(exceptionHandler); -// -// assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_PG_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); -// verify(connectFunc, times(2)).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// private @NotNull PluginServiceImpl getPluginService(String protocol) throws SQLException { -// return new PluginServiceImpl( -// mockServicesContainer, -// new ExceptionManager(), -// TEST_PROPS, -// "url", -// protocol, -// mockDialectManager, -// mockTargetDriverDialect, -// configurationProfile, -// mockSessionStateService); -// } -// -// /** -// * The plugin will attempt to open a connection after fetching a secret, but it will fail because the returned secret -// * could not be parsed. -// */ -// @Test -// public void testFailedToReadSecrets() throws SQLException { -// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) -// .thenReturn(INVALID_GET_SECRET_VALUE_RESPONSE); -// -// final SQLException readSecretsFailedException = -// assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_PG_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals( -// readSecretsFailedException.getMessage(), -// Messages.get( -// "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); -// assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); -// verify(this.connectFunc, never()).call(); -// } -// -// /** -// * The plugin will attempt to open a connection after fetching a secret, but it will fail because an exception was -// * thrown by the AWS Secrets Manager. -// */ -// @Test -// public void testFailedToGetSecrets() throws SQLException { -// doThrow(SecretsManagerException.class).when(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); -// -// final SQLException getSecretsFailedException = -// assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_PG_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals( -// getSecretsFailedException.getMessage(), -// Messages.get( -// "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); -// assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); -// verify(this.connectFunc, never()).call(); -// } -// -// @ParameterizedTest -// @ValueSource(strings = {"28000", "28P01"}) -// public void testFailedInitialConnectionWithWrappedGenericError(final String accessError) throws SQLException { -// this.plugin = new AwsSecretsManagerConnectionPlugin( -// getPluginService(TEST_PG_PROTOCOL), -// TEST_PROPS, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest); -// -// // Fail the initial connection attempt with a wrapped exception. -// // Second attempt should be successful. -// final SQLException targetException = new SQLException(TEST_SQL_ERROR, accessError); -// final SQLException wrappedException = new SQLException(targetException); -// doThrow(wrappedException).when(connectFunc).call(); -// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) -// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); -// -// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); -// -// assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_PG_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(connectFunc).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// @Test -// public void testConnectWithWrappedMySQLException() throws SQLException { -// this.plugin = new AwsSecretsManagerConnectionPlugin( -// getPluginService(TEST_MYSQL_PROTOCOL), -// TEST_PROPS, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest); -// -// final CJException targetException = new CJException("28000"); -// final SQLException wrappedException = new SQLException(targetException); -// -// doThrow(wrappedException).when(connectFunc).call(); -// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) -// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); -// -// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); -// -// assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_MYSQL_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(connectFunc).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// @Test -// public void testConnectWithWrappedPostgreSQLException() throws SQLException { -// this.plugin = new AwsSecretsManagerConnectionPlugin( -// getPluginService(TEST_PG_PROTOCOL), -// TEST_PROPS, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest); -// -// final PSQLException targetException = new PSQLException("login error", PSQLState.INVALID_PASSWORD, null); -// final SQLException wrappedException = new SQLException(targetException); -// -// doThrow(wrappedException).when(connectFunc).call(); -// when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) -// .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); -// -// when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); -// -// assertThrows( -// SQLException.class, -// () -> this.plugin.connect( -// TEST_PG_PROTOCOL, -// TEST_HOSTSPEC, -// TEST_PROPS, -// true, -// this.connectFunc)); -// -// assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); -// verify(connectFunc).call(); -// assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); -// assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); -// } -// -// @ParameterizedTest -// @MethodSource("arnArguments") -// public void testConnectViaARN(final String arn, final Region expectedRegionParsedFromARN) -// throws SQLException { -// final Properties props = new Properties(); -// -// SECRET_ID_PROPERTY.set(props, arn); -// -// this.plugin = spy(new AwsSecretsManagerConnectionPlugin( -// new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), -// props, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest)); -// -// final Pair secret = this.plugin.secretKey; -// assertEquals(expectedRegionParsedFromARN, Region.of(secret.getValue2())); -// } -// -// @ParameterizedTest -// @MethodSource("arnArguments") -// public void testConnectionWithRegionParameterAndARN(final String arn, final Region regionParsedFromARN) -// throws SQLException { -// final Region expectedRegion = Region.US_ISO_EAST_1; -// -// final Properties props = new Properties(); -// SECRET_ID_PROPERTY.set(props, arn); -// REGION_PROPERTY.set(props, expectedRegion.toString()); -// -// this.plugin = spy(new AwsSecretsManagerConnectionPlugin( -// new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), -// props, -// (host, r) -> mockSecretsManagerClient, -// (id) -> mockGetValueRequest)); -// -// final Pair secret = this.plugin.secretKey; -// // The region specified in `secretsManagerRegion` should override the region parsed from ARN. -// assertNotEquals(regionParsedFromARN, Region.of(secret.getValue2())); -// assertEquals(expectedRegion, Region.of(secret.getValue2())); -// } -// -// private static Stream provideExceptionCodeForDifferentDrivers() { -// return Stream.of( -// Arguments.of("28000", TEST_MYSQL_PROTOCOL, new MySQLExceptionHandler()), -// Arguments.of("28P01", TEST_PG_PROTOCOL, new PgExceptionHandler()) -// ); -// } -// -// private static Stream arnArguments() { -// return Stream.of( -// Arguments.of("arn:aws:secretsmanager:us-east-2:123456789012:secret:foo", Region.US_EAST_2), -// Arguments.of("arn:aws:secretsmanager:us-west-1:123456789012:secret:boo", Region.US_WEST_1), -// Arguments.of( -// "arn:aws:secretsmanager:us-east-2:123456789012:secret:rds!cluster-bar-foo", -// Region.US_EAST_2) -// ); -// } -// -// private static Stream missingArguments() { -// final Properties missingId = new Properties(); -// REGION_PROPERTY.set(missingId, TEST_REGION); -// -// final Properties missingRegion = new Properties(); -// SECRET_ID_PROPERTY.set(missingRegion, TEST_SECRET_ID); -// -// return Stream.of( -// Arguments.of(missingId), -// Arguments.of(missingRegion) -// ); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.REGION_PROPERTY; +import static software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin.SECRET_ID_PROPERTY; + +import com.mysql.cj.exceptions.CJException; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Properties; +import java.util.stream.Stream; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.postgresql.util.PSQLException; +import org.postgresql.util.PSQLState; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; +import software.amazon.awssdk.services.secretsmanager.model.SecretsManagerException; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginServiceImpl; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.dialect.DialectManager; +import software.amazon.jdbc.exceptions.ExceptionHandler; +import software.amazon.jdbc.exceptions.ExceptionManager; +import software.amazon.jdbc.exceptions.MySQLExceptionHandler; +import software.amazon.jdbc.exceptions.PgExceptionHandler; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.profile.ConfigurationProfile; +import software.amazon.jdbc.profile.ConfigurationProfileBuilder; +import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.Pair; +import software.amazon.jdbc.util.telemetry.GaugeCallable; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +@SuppressWarnings("resource") +public class AwsSecretsManagerConnectionPluginTest { + + private static final String TEST_PG_PROTOCOL = "jdbc:aws-wrapper:postgresql:"; + private static final String TEST_MYSQL_PROTOCOL = "jdbc:aws-wrapper:mysql:"; + private static final String TEST_REGION = "us-east-2"; + private static final String TEST_SECRET_ID = "secretId"; + private static final String TEST_USERNAME = "testUser"; + private static final String TEST_PASSWORD = "testPassword"; + private static final String VALID_SECRET_STRING = + "{\"username\": \"" + TEST_USERNAME + "\", \"password\": \"" + TEST_PASSWORD + "\"}"; + private static final String INVALID_SECRET_STRING = "{username: invalid, password: invalid}"; + private static final String TEST_HOST = "test-domain"; + private static final String TEST_SQL_ERROR = "SQL exception error message"; + private static final String UNHANDLED_ERROR_CODE = "HY000"; + private static final int TEST_PORT = 5432; + private static final Pair SECRET_CACHE_KEY = Pair.create(TEST_SECRET_ID, TEST_REGION); + private static final AwsSecretsManagerConnectionPlugin.Secret TEST_SECRET = + new AwsSecretsManagerConnectionPlugin.Secret("testUser", "testPassword"); + private static final HostSpec TEST_HOSTSPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(TEST_HOST).port(TEST_PORT).build(); + private static final GetSecretValueResponse VALID_GET_SECRET_VALUE_RESPONSE = + GetSecretValueResponse.builder().secretString(VALID_SECRET_STRING).build(); + private static final GetSecretValueResponse INVALID_GET_SECRET_VALUE_RESPONSE = + GetSecretValueResponse.builder().secretString(INVALID_SECRET_STRING).build(); + private static final Properties TEST_PROPS = new Properties(); + private AwsSecretsManagerConnectionPlugin plugin; + + private AutoCloseable closeable; + + @Mock FullServicesContainer mockServicesContainer; + @Mock SecretsManagerClient mockSecretsManagerClient; + @Mock GetSecretValueRequest mockGetValueRequest; + @Mock JdbcCallable connectFunc; + @Mock PluginServiceImpl mockService; + @Mock ConnectionPluginManager mockConnectionPluginManager; + @Mock Dialect mockTopologyAwareDialect; + @Mock DialectManager mockDialectManager; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryGauge mockTelemetryGauge; + @Mock TargetDriverDialect mockTargetDriverDialect; + ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); + + @Mock SessionStateService mockSessionStateService; + + @BeforeEach + public void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + + REGION_PROPERTY.set(TEST_PROPS, TEST_REGION); + SECRET_ID_PROPERTY.set(TEST_PROPS, TEST_SECRET_ID); + + when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) + .thenReturn(mockTopologyAwareDialect); + + when(mockServicesContainer.getConnectionPluginManager()).thenReturn(mockConnectionPluginManager); + when(mockService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + // noinspection unchecked + when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + + this.plugin = new AwsSecretsManagerConnectionPlugin( + mockService, + TEST_PROPS, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest); + + when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) + .thenReturn(mockTopologyAwareDialect); + + when(mockService.getHostSpecBuilder()).thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + AwsSecretsManagerCacheHolder.clearCache(); + TEST_PROPS.clear(); + } + + /** + * The plugin will successfully open a connection with a cached secret. + */ + @Test + public void testConnectWithCachedSecrets() throws SQLException { + // Add initial cached secret to be used for a connection. + AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); + + this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); + + assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); + verify(this.connectFunc).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + /** + * The plugin will attempt to open a connection with an empty secret cache. The plugin will fetch the secret from the + * AWS Secrets Manager. + */ + @Test + public void testConnectWithNewSecrets() throws SQLException { + when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) + .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); + + this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); + + assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); + verify(connectFunc).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + @ParameterizedTest + @MethodSource("missingArguments") + public void testMissingRequiredParameters(final Properties properties) { + assertThrows(RuntimeException.class, () -> new AwsSecretsManagerConnectionPlugin( + mockService, + properties, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest)); + } + + /** + * The plugin will attempt to open a connection with a cached secret, but it will fail with a generic SQL exception. + * In this case, the plugin will rethrow the error back to the user. + */ + @Test + public void testFailedInitialConnectionWithUnhandledError() throws SQLException { + AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); + final SQLException failedFirstConnectionGenericException = new SQLException(TEST_SQL_ERROR, UNHANDLED_ERROR_CODE); + doThrow(failedFirstConnectionGenericException).when(connectFunc).call(); + + final SQLException connectionFailedException = assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_PG_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals(TEST_SQL_ERROR, connectionFailedException.getMessage()); + verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); + verify(connectFunc).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + /** + * The plugin will attempt to open a connection with a cached secret, but it will fail with an access error. In this + * case, the plugin will fetch the secret and will retry the connection. + */ + @ParameterizedTest + @MethodSource("provideExceptionCodeForDifferentDrivers") + public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( + String accessError, + String protocol, + ExceptionHandler exceptionHandler) throws SQLException { + this.plugin = new AwsSecretsManagerConnectionPlugin( + getPluginService(protocol), + TEST_PROPS, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest); + + // Fail the initial connection attempt with cached secret. + // Second attempt should be successful. + AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); + final SQLException failedFirstConnectionAccessException = new SQLException(TEST_SQL_ERROR, + accessError); + doThrow(failedFirstConnectionAccessException).when(connectFunc).call(); + when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) + .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); + + when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(exceptionHandler); + + assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_PG_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); + verify(connectFunc, times(2)).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + private @NotNull PluginServiceImpl getPluginService(String protocol) throws SQLException { + return new PluginServiceImpl( + mockServicesContainer, + new ExceptionManager(), + TEST_PROPS, + "url", + protocol, + mockDialectManager, + mockTargetDriverDialect, + configurationProfile, + mockSessionStateService); + } + + /** + * The plugin will attempt to open a connection after fetching a secret, but it will fail because the returned secret + * could not be parsed. + */ + @Test + public void testFailedToReadSecrets() throws SQLException { + when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) + .thenReturn(INVALID_GET_SECRET_VALUE_RESPONSE); + + final SQLException readSecretsFailedException = + assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_PG_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals( + readSecretsFailedException.getMessage(), + Messages.get( + "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); + assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); + verify(this.connectFunc, never()).call(); + } + + /** + * The plugin will attempt to open a connection after fetching a secret, but it will fail because an exception was + * thrown by the AWS Secrets Manager. + */ + @Test + public void testFailedToGetSecrets() throws SQLException { + doThrow(SecretsManagerException.class).when(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); + + final SQLException getSecretsFailedException = + assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_PG_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals( + getSecretsFailedException.getMessage(), + Messages.get( + "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials")); + assertEquals(0, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); + verify(this.connectFunc, never()).call(); + } + + @ParameterizedTest + @ValueSource(strings = {"28000", "28P01"}) + public void testFailedInitialConnectionWithWrappedGenericError(final String accessError) throws SQLException { + this.plugin = new AwsSecretsManagerConnectionPlugin( + getPluginService(TEST_PG_PROTOCOL), + TEST_PROPS, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest); + + // Fail the initial connection attempt with a wrapped exception. + // Second attempt should be successful. + final SQLException targetException = new SQLException(TEST_SQL_ERROR, accessError); + final SQLException wrappedException = new SQLException(targetException); + doThrow(wrappedException).when(connectFunc).call(); + when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) + .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); + + when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); + + assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_PG_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(connectFunc).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + @Test + public void testConnectWithWrappedMySQLException() throws SQLException { + this.plugin = new AwsSecretsManagerConnectionPlugin( + getPluginService(TEST_MYSQL_PROTOCOL), + TEST_PROPS, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest); + + final CJException targetException = new CJException("28000"); + final SQLException wrappedException = new SQLException(targetException); + + doThrow(wrappedException).when(connectFunc).call(); + when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) + .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); + + when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); + + assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_MYSQL_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(connectFunc).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + @Test + public void testConnectWithWrappedPostgreSQLException() throws SQLException { + this.plugin = new AwsSecretsManagerConnectionPlugin( + getPluginService(TEST_PG_PROTOCOL), + TEST_PROPS, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest); + + final PSQLException targetException = new PSQLException("login error", PSQLState.INVALID_PASSWORD, null); + final SQLException wrappedException = new SQLException(targetException); + + doThrow(wrappedException).when(connectFunc).call(); + when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) + .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); + + when(mockTopologyAwareDialect.getExceptionHandler()).thenReturn(new PgExceptionHandler()); + + assertThrows( + SQLException.class, + () -> this.plugin.connect( + TEST_PG_PROTOCOL, + TEST_HOSTSPEC, + TEST_PROPS, + true, + this.connectFunc)); + + assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); + verify(connectFunc).call(); + assertEquals(TEST_USERNAME, TEST_PROPS.get(PropertyDefinition.USER.name)); + assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); + } + + @ParameterizedTest + @MethodSource("arnArguments") + public void testConnectViaARN(final String arn, final Region expectedRegionParsedFromARN) + throws SQLException { + final Properties props = new Properties(); + + SECRET_ID_PROPERTY.set(props, arn); + + this.plugin = spy(new AwsSecretsManagerConnectionPlugin( + new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), + props, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest)); + + final Pair secret = this.plugin.secretKey; + assertEquals(expectedRegionParsedFromARN, Region.of(secret.getValue2())); + } + + @ParameterizedTest + @MethodSource("arnArguments") + public void testConnectionWithRegionParameterAndARN(final String arn, final Region regionParsedFromARN) + throws SQLException { + final Region expectedRegion = Region.US_ISO_EAST_1; + + final Properties props = new Properties(); + SECRET_ID_PROPERTY.set(props, arn); + REGION_PROPERTY.set(props, expectedRegion.toString()); + + this.plugin = spy(new AwsSecretsManagerConnectionPlugin( + new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), + props, + (host, r) -> mockSecretsManagerClient, + (id) -> mockGetValueRequest)); + + final Pair secret = this.plugin.secretKey; + // The region specified in `secretsManagerRegion` should override the region parsed from ARN. + assertNotEquals(regionParsedFromARN, Region.of(secret.getValue2())); + assertEquals(expectedRegion, Region.of(secret.getValue2())); + } + + private static Stream provideExceptionCodeForDifferentDrivers() { + return Stream.of( + Arguments.of("28000", TEST_MYSQL_PROTOCOL, new MySQLExceptionHandler()), + Arguments.of("28P01", TEST_PG_PROTOCOL, new PgExceptionHandler()) + ); + } + + private static Stream arnArguments() { + return Stream.of( + Arguments.of("arn:aws:secretsmanager:us-east-2:123456789012:secret:foo", Region.US_EAST_2), + Arguments.of("arn:aws:secretsmanager:us-west-1:123456789012:secret:boo", Region.US_WEST_1), + Arguments.of( + "arn:aws:secretsmanager:us-east-2:123456789012:secret:rds!cluster-bar-foo", + Region.US_EAST_2) + ); + } + + private static Stream missingArguments() { + final Properties missingId = new Properties(); + REGION_PROPERTY.set(missingId, TEST_REGION); + + final Properties missingRegion = new Properties(); + SECRET_ID_PROPERTY.set(missingRegion, TEST_SECRET_ID); + + return Stream.of( + Arguments.of(missingId), + Arguments.of(missingRegion) + ); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java index d44e081b0..d8fac47be 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java @@ -1,138 +1,138 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyBoolean; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.atLeastOnce; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.ArrayList; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.List; -// import java.util.Properties; -// import java.util.stream.Stream; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.Arguments; -// import org.junit.jupiter.params.provider.MethodSource; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.ConnectionProvider; -// import software.amazon.jdbc.ConnectionProviderManager; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginManagerService; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.util.telemetry.GaugeCallable; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryGauge; -// -// class DefaultConnectionPluginTest { -// -// private DefaultConnectionPlugin plugin; -// -// @Mock PluginService pluginService; -// @Mock ConnectionProvider connectionProvider; -// @Mock PluginManagerService pluginManagerService; -// @Mock JdbcCallable mockSqlFunction; -// @Mock JdbcCallable mockConnectFunction; -// @Mock Connection conn; -// @Mock Connection oldConn; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryGauge mockTelemetryGauge; -// @Mock ConnectionProviderManager mockConnectionProviderManager; -// @Mock HostSpec mockHostSpec; -// -// -// private AutoCloseable closeable; -// -// @BeforeEach -// void setUp() { -// closeable = MockitoAnnotations.openMocks(this); -// -// when(pluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// // noinspection unchecked -// when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); -// when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any())) -// .thenReturn(connectionProvider); -// -// plugin = new DefaultConnectionPlugin( -// pluginService, connectionProvider, null, pluginManagerService, mockConnectionProviderManager); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @ParameterizedTest -// @MethodSource("multiStatementQueries") -// void testParseMultiStatementQueries(final String sql, final List expected) { -// final List actual = plugin.parseMultiStatementQueries(sql); -// assertEquals(expected, actual); -// } -// -// @Test -// void testExecute_closeCurrentConnection() throws SQLException { -// when(this.pluginService.getCurrentConnection()).thenReturn(conn); -// plugin.execute(Void.class, SQLException.class, conn, "Connection.close", mockSqlFunction, new Object[]{}); -// verify(pluginManagerService, times(1)).setInTransaction(false); -// } -// -// @Test -// void testExecute_closeOldConnection() throws SQLException { -// when(this.pluginService.getCurrentConnection()).thenReturn(conn); -// plugin.execute(Void.class, SQLException.class, oldConn, "Connection.close", mockSqlFunction, new Object[]{}); -// verify(pluginManagerService, never()).setInTransaction(anyBoolean()); -// } -// -// @Test -// void testConnect() throws SQLException { -// plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction); -// verify(connectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any(), any()); -// verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any()); -// } -// -// private static Stream multiStatementQueries() { -// return Stream.of( -// Arguments.of("", new ArrayList()), -// Arguments.of(null, new ArrayList()), -// Arguments.of(" ", new ArrayList()), -// Arguments.of("some \t \r \n query;", Collections.singletonList("some query")), -// Arguments.of("some\t\t\r\n query;query2", Arrays.asList("some query", "query2")) -// ); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.ConnectionProviderManager; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginManagerService; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.telemetry.GaugeCallable; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; + +class DefaultConnectionPluginTest { + + private DefaultConnectionPlugin plugin; + + @Mock PluginService pluginService; + @Mock ConnectionProvider connectionProvider; + @Mock PluginManagerService pluginManagerService; + @Mock JdbcCallable mockSqlFunction; + @Mock JdbcCallable mockConnectFunction; + @Mock Connection conn; + @Mock Connection oldConn; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryGauge mockTelemetryGauge; + @Mock ConnectionProviderManager mockConnectionProviderManager; + @Mock HostSpec mockHostSpec; + + + private AutoCloseable closeable; + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + + when(pluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + // noinspection unchecked + when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); + when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any())) + .thenReturn(connectionProvider); + + plugin = new DefaultConnectionPlugin( + pluginService, connectionProvider, null, pluginManagerService, mockConnectionProviderManager); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @ParameterizedTest + @MethodSource("multiStatementQueries") + void testParseMultiStatementQueries(final String sql, final List expected) { + final List actual = plugin.parseMultiStatementQueries(sql); + assertEquals(expected, actual); + } + + @Test + void testExecute_closeCurrentConnection() throws SQLException { + when(this.pluginService.getCurrentConnection()).thenReturn(conn); + plugin.execute(Void.class, SQLException.class, conn, "Connection.close", mockSqlFunction, new Object[]{}); + verify(pluginManagerService, times(1)).setInTransaction(false); + } + + @Test + void testExecute_closeOldConnection() throws SQLException { + when(this.pluginService.getCurrentConnection()).thenReturn(conn); + plugin.execute(Void.class, SQLException.class, oldConn, "Connection.close", mockSqlFunction, new Object[]{}); + verify(pluginManagerService, never()).setInTransaction(anyBoolean()); + } + + @Test + void testConnect() throws SQLException { + plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction); + verify(connectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any(), any()); + verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any()); + } + + private static Stream multiStatementQueries() { + return Stream.of( + Arguments.of("", new ArrayList()), + Arguments.of(null, new ArrayList()), + Arguments.of(" ", new ArrayList()), + Arguments.of("some \t \r \n query;", Collections.singletonList("some query")), + Arguments.of("some\t\t\r\n query;query2", Arrays.asList("some query", "query2")) + ); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java index e465df7f9..0d41c5f72 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java @@ -1,159 +1,159 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.customendpoint; -// -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// import static software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.HashSet; -// import java.util.Properties; -// import java.util.function.BiFunction; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.awssdk.regions.Region; -// import software.amazon.awssdk.services.rds.RdsClient; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.monitoring.MonitorService; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// -// public class CustomEndpointPluginTest { -// private final String writerClusterUrl = "writer.cluster-XYZ.us-east-1.rds.amazonaws.com"; -// private final String customEndpointUrl = "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; -// -// private AutoCloseable closeable; -// private final Properties props = new Properties(); -// private final HostAvailabilityStrategy availabilityStrategy = new SimpleHostAvailabilityStrategy(); -// private final HostSpecBuilder hostSpecBuilder = new HostSpecBuilder(availabilityStrategy); -// private final HostSpec writerClusterHost = hostSpecBuilder.host(writerClusterUrl).build(); -// private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); -// -// @Mock private FullServicesContainer mockServicesContainer; -// @Mock private PluginService mockPluginService; -// @Mock private MonitorService mockMonitorService; -// @Mock private BiFunction mockRdsClientFunc; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock private TelemetryCounter mockTelemetryCounter; -// @Mock private JdbcCallable mockConnectFunc; -// @Mock private JdbcCallable mockJdbcMethodFunc; -// @Mock private Connection mockConnection; -// @Mock private CustomEndpointMonitor mockMonitor; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// -// -// @BeforeEach -// public void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// -// when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); -// when(mockServicesContainer.getMonitorService()).thenReturn(mockMonitorService); -// when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.createCounter(any(String.class))).thenReturn(mockTelemetryCounter); -// when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); -// when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); -// when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// props.clear(); -// } -// -// private CustomEndpointPlugin getSpyPlugin() throws SQLException { -// CustomEndpointPlugin plugin = new CustomEndpointPlugin(mockServicesContainer, props, mockRdsClientFunc); -// CustomEndpointPlugin spyPlugin = spy(plugin); -// doReturn(mockMonitor).when(spyPlugin).createMonitorIfAbsent(any(Properties.class)); -// return spyPlugin; -// } -// -// @Test -// public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { -// CustomEndpointPlugin spyPlugin = getSpyPlugin(); -// -// spyPlugin.connect("", writerClusterHost, props, true, mockConnectFunc); -// -// verify(mockConnectFunc, times(1)).call(); -// verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); -// } -// -// @Test -// public void testConnect_monitorCreated() throws SQLException { -// CustomEndpointPlugin spyPlugin = getSpyPlugin(); -// -// spyPlugin.connect("", host, props, true, mockConnectFunc); -// -// verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); -// verify(mockConnectFunc, times(1)).call(); -// } -// -// @Test -// public void testConnect_timeoutWaitingForInfo() throws SQLException { -// WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set(props, "1"); -// CustomEndpointPlugin spyPlugin = getSpyPlugin(); -// when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); -// -// assertThrows(SQLException.class, () -> spyPlugin.connect("", host, props, true, mockConnectFunc)); -// -// verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); -// verify(mockConnectFunc, never()).call(); -// } -// -// @Test -// public void testExecute_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { -// CustomEndpointPlugin spyPlugin = getSpyPlugin(); -// -// spyPlugin.execute( -// Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); -// -// verify(mockJdbcMethodFunc, times(1)).call(); -// verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); -// } -// -// @Test -// public void testExecute_monitorCreated() throws SQLException { -// CustomEndpointPlugin spyPlugin = getSpyPlugin(); -// spyPlugin.customEndpointHostSpec = host; -// -// spyPlugin.execute( -// Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); -// -// verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); -// verify(mockJdbcMethodFunc, times(1)).call(); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.HashSet; +import java.util.Properties; +import java.util.function.BiFunction; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.monitoring.MonitorService; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +public class CustomEndpointPluginTest { + private final String writerClusterUrl = "writer.cluster-XYZ.us-east-1.rds.amazonaws.com"; + private final String customEndpointUrl = "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; + + private AutoCloseable closeable; + private final Properties props = new Properties(); + private final HostAvailabilityStrategy availabilityStrategy = new SimpleHostAvailabilityStrategy(); + private final HostSpecBuilder hostSpecBuilder = new HostSpecBuilder(availabilityStrategy); + private final HostSpec writerClusterHost = hostSpecBuilder.host(writerClusterUrl).build(); + private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); + + @Mock private FullServicesContainer mockServicesContainer; + @Mock private PluginService mockPluginService; + @Mock private MonitorService mockMonitorService; + @Mock private BiFunction mockRdsClientFunc; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryCounter mockTelemetryCounter; + @Mock private JdbcCallable mockConnectFunc; + @Mock private JdbcCallable mockJdbcMethodFunc; + @Mock private Connection mockConnection; + @Mock private CustomEndpointMonitor mockMonitor; + @Mock TargetDriverDialect mockTargetDriverDialect; + + + @BeforeEach + public void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + + when(mockServicesContainer.getPluginService()).thenReturn(mockPluginService); + when(mockServicesContainer.getMonitorService()).thenReturn(mockMonitorService); + when(mockServicesContainer.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any(String.class))).thenReturn(mockTelemetryCounter); + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); + when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); + when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + props.clear(); + } + + private CustomEndpointPlugin getSpyPlugin() throws SQLException { + CustomEndpointPlugin plugin = new CustomEndpointPlugin(mockServicesContainer, props, mockRdsClientFunc); + CustomEndpointPlugin spyPlugin = spy(plugin); + doReturn(mockMonitor).when(spyPlugin).createMonitorIfAbsent(any(Properties.class)); + return spyPlugin; + } + + @Test + public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + + spyPlugin.connect("", writerClusterHost, props, true, mockConnectFunc); + + verify(mockConnectFunc, times(1)).call(); + verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); + } + + @Test + public void testConnect_monitorCreated() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + + spyPlugin.connect("", host, props, true, mockConnectFunc); + + verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); + verify(mockConnectFunc, times(1)).call(); + } + + @Test + public void testConnect_timeoutWaitingForInfo() throws SQLException { + WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set(props, "1"); + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); + + assertThrows(SQLException.class, () -> spyPlugin.connect("", host, props, true, mockConnectFunc)); + + verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); + verify(mockConnectFunc, never()).call(); + } + + @Test + public void testExecute_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + + spyPlugin.execute( + Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); + + verify(mockJdbcMethodFunc, times(1)).call(); + verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); + } + + @Test + public void testExecute_monitorCreated() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + spyPlugin.customEndpointHostSpec = host; + + spyPlugin.execute( + Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); + + verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); + verify(mockJdbcMethodFunc, times(1)).call(); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java index 377b511eb..5c0293bd3 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java @@ -1,356 +1,356 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.dev; -// -// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -// import static org.junit.jupiter.api.Assertions.assertInstanceOf; -// import static org.junit.jupiter.api.Assertions.assertNotNull; -// import static org.junit.jupiter.api.Assertions.assertNotSame; -// import static org.junit.jupiter.api.Assertions.assertSame; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyBoolean; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.ConnectionPluginManager; -// import software.amazon.jdbc.ConnectionProvider; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.dialect.DialectCodes; -// import software.amazon.jdbc.dialect.DialectManager; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.FullServicesContainerImpl; -// import software.amazon.jdbc.util.monitoring.MonitorService; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.wrapper.ConnectionWrapper; -// -// @SuppressWarnings({"resource"}) -// public class DeveloperConnectionPluginTest { -// private FullServicesContainer servicesContainer; -// @Mock StorageService mockStorageService; -// @Mock MonitorService mockMonitorService; -// @Mock ConnectionProvider mockConnectionProvider; -// @Mock Connection mockConnection; -// @Mock ConnectionPluginManager mockConnectionPluginManager; -// @Mock ExceptionSimulatorConnectCallback mockConnectCallback; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// -// private AutoCloseable closeable; -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @BeforeEach -// void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// servicesContainer = new FullServicesContainerImpl( -// mockStorageService, mockMonitorService, mockConnectionProvider, mockTelemetryFactory); -// -// when(mockConnectionProvider.connect(any(), any())).thenReturn(mockConnection); -// when(mockConnectCallback.getExceptionToRaise(any(), any(), anyBoolean())).thenReturn(null); -// -// when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); -// when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); -// } -// -// @Test -// @SuppressWarnings("try") -// public void test_RaiseException() throws SQLException { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// try (ConnectionWrapper wrapper = new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)) { -// -// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); -// assertNotNull(simulator); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// final RuntimeException runtimeException = new RuntimeException("test"); -// simulator.raiseExceptionOnNextCall(runtimeException); -// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); -// assertSame(runtimeException, thrownException); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// } -// } -// -// @Test -// public void test_RaiseExceptionForMethodName() throws SQLException { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// try (ConnectionWrapper wrapper = new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)) { -// -// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); -// assertNotNull(simulator); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// final RuntimeException runtimeException = new RuntimeException("test"); -// simulator.raiseExceptionOnNextCall("Connection.createStatement", runtimeException); -// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); -// assertSame(runtimeException, thrownException); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// } -// } -// -// @Test -// public void test_RaiseExceptionForAnyMethodName() throws SQLException { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// try (ConnectionWrapper wrapper = new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)) { -// -// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); -// assertNotNull(simulator); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// final RuntimeException runtimeException = new RuntimeException("test"); -// simulator.raiseExceptionOnNextCall("*", runtimeException); -// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); -// assertSame(runtimeException, thrownException); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// } -// } -// -// @Test -// public void test_RaiseExceptionForWrongMethodName() throws SQLException { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// try (ConnectionWrapper wrapper = new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)) { -// -// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); -// assertNotNull(simulator); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// final RuntimeException runtimeException = new RuntimeException("test"); -// simulator.raiseExceptionOnNextCall("Connection.isClosed", runtimeException); -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// Throwable thrownException = assertThrows(RuntimeException.class, wrapper::isClosed); -// assertSame(runtimeException, thrownException); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// } -// } -// -// @Test -// public void test_RaiseExpectedExceptionClass() throws SQLException { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// try (ConnectionWrapper wrapper = new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)) { -// -// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); -// assertNotNull(simulator); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// final SQLException sqlException = new SQLException("test"); -// simulator.raiseExceptionOnNextCall(sqlException); -// Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); -// assertSame(sqlException, thrownException); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// } -// } -// -// @Test -// public void test_RaiseUnexpectedExceptionClass() throws SQLException { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// try (ConnectionWrapper wrapper = new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)) { -// -// ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); -// assertNotNull(simulator); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// -// final Exception exception = new Exception("test"); -// simulator.raiseExceptionOnNextCall(exception); -// Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); -// assertNotNull(thrownException); -// assertNotSame(exception, thrownException); -// assertInstanceOf(SQLException.class, thrownException); -// assertNotNull(thrownException.getCause()); -// assertSame(thrownException.getCause(), exception); -// -// assertDoesNotThrow(() -> wrapper.createStatement()); -// } -// } -// -// @Test -// public void test_RaiseExceptionOnConnect() { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// -// final SQLException exception = new SQLException("test"); -// ExceptionSimulatorManager.raiseExceptionOnNextConnect(exception); -// -// Throwable thrownException = assertThrows( -// SQLException.class, -// () -> new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)); -// assertSame(exception, thrownException); -// -// assertDoesNotThrow( -// () -> new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)); -// } -// -// @Test -// public void test_NoExceptionOnConnectWithCallback() { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// -// ExceptionSimulatorManager.setCallback(mockConnectCallback); -// -// assertDoesNotThrow( -// () -> new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)); -// } -// -// @Test -// public void test_RaiseExceptionOnConnectWithCallback() { -// -// final Properties props = new Properties(); -// props.put(PropertyDefinition.PLUGINS.name, "dev"); -// props.put(DialectManager.DIALECT.name, DialectCodes.PG); -// -// final SQLException exception = new SQLException("test"); -// when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())) -// .thenReturn(exception) -// .thenReturn(null); -// ExceptionSimulatorManager.setCallback(mockConnectCallback); -// -// Throwable thrownException = assertThrows( -// SQLException.class, -// () -> new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)); -// assertSame(exception, thrownException); -// -// assertDoesNotThrow( -// () -> new ConnectionWrapper( -// servicesContainer, -// props, -// "any-protocol://any-host/", -// mockConnectionProvider, -// null, -// mockTargetDriverDialect, -// null)); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.dev; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.DialectCodes; +import software.amazon.jdbc.dialect.DialectManager; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.FullServicesContainerImpl; +import software.amazon.jdbc.util.monitoring.MonitorService; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.wrapper.ConnectionWrapper; + +@SuppressWarnings({"resource"}) +public class DeveloperConnectionPluginTest { + private FullServicesContainer servicesContainer; + @Mock StorageService mockStorageService; + @Mock MonitorService mockMonitorService; + @Mock ConnectionProvider mockConnectionProvider; + @Mock Connection mockConnection; + @Mock ConnectionPluginManager mockConnectionPluginManager; + @Mock ExceptionSimulatorConnectCallback mockConnectCallback; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock TelemetryContext mockTelemetryContext; + @Mock TargetDriverDialect mockTargetDriverDialect; + + private AutoCloseable closeable; + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @BeforeEach + void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + servicesContainer = new FullServicesContainerImpl( + mockStorageService, mockMonitorService, mockConnectionProvider, mockTelemetryFactory); + + when(mockConnectionProvider.connect(any(), any())).thenReturn(mockConnection); + when(mockConnectCallback.getExceptionToRaise(any(), any(), anyBoolean())).thenReturn(null); + + when(mockConnectionPluginManager.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); + when(mockTelemetryFactory.openTelemetryContext(eq(null), any())).thenReturn(mockTelemetryContext); + } + + @Test + @SuppressWarnings("try") + public void test_RaiseException() throws SQLException { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + try (ConnectionWrapper wrapper = new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)) { + + ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); + assertNotNull(simulator); + + assertDoesNotThrow(() -> wrapper.createStatement()); + + final RuntimeException runtimeException = new RuntimeException("test"); + simulator.raiseExceptionOnNextCall(runtimeException); + Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); + assertSame(runtimeException, thrownException); + + assertDoesNotThrow(() -> wrapper.createStatement()); + } + } + + @Test + public void test_RaiseExceptionForMethodName() throws SQLException { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + try (ConnectionWrapper wrapper = new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)) { + + ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); + assertNotNull(simulator); + + assertDoesNotThrow(() -> wrapper.createStatement()); + + final RuntimeException runtimeException = new RuntimeException("test"); + simulator.raiseExceptionOnNextCall("Connection.createStatement", runtimeException); + Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); + assertSame(runtimeException, thrownException); + + assertDoesNotThrow(() -> wrapper.createStatement()); + } + } + + @Test + public void test_RaiseExceptionForAnyMethodName() throws SQLException { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + try (ConnectionWrapper wrapper = new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)) { + + ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); + assertNotNull(simulator); + + assertDoesNotThrow(() -> wrapper.createStatement()); + + final RuntimeException runtimeException = new RuntimeException("test"); + simulator.raiseExceptionOnNextCall("*", runtimeException); + Throwable thrownException = assertThrows(RuntimeException.class, wrapper::createStatement); + assertSame(runtimeException, thrownException); + + assertDoesNotThrow(() -> wrapper.createStatement()); + } + } + + @Test + public void test_RaiseExceptionForWrongMethodName() throws SQLException { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + try (ConnectionWrapper wrapper = new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)) { + + ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); + assertNotNull(simulator); + + assertDoesNotThrow(() -> wrapper.createStatement()); + + final RuntimeException runtimeException = new RuntimeException("test"); + simulator.raiseExceptionOnNextCall("Connection.isClosed", runtimeException); + assertDoesNotThrow(() -> wrapper.createStatement()); + + Throwable thrownException = assertThrows(RuntimeException.class, wrapper::isClosed); + assertSame(runtimeException, thrownException); + + assertDoesNotThrow(() -> wrapper.createStatement()); + } + } + + @Test + public void test_RaiseExpectedExceptionClass() throws SQLException { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + try (ConnectionWrapper wrapper = new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)) { + + ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); + assertNotNull(simulator); + + assertDoesNotThrow(() -> wrapper.createStatement()); + + final SQLException sqlException = new SQLException("test"); + simulator.raiseExceptionOnNextCall(sqlException); + Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); + assertSame(sqlException, thrownException); + + assertDoesNotThrow(() -> wrapper.createStatement()); + } + } + + @Test + public void test_RaiseUnexpectedExceptionClass() throws SQLException { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + try (ConnectionWrapper wrapper = new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)) { + + ExceptionSimulator simulator = wrapper.unwrap(ExceptionSimulator.class); + assertNotNull(simulator); + + assertDoesNotThrow(() -> wrapper.createStatement()); + + final Exception exception = new Exception("test"); + simulator.raiseExceptionOnNextCall(exception); + Throwable thrownException = assertThrows(SQLException.class, wrapper::createStatement); + assertNotNull(thrownException); + assertNotSame(exception, thrownException); + assertInstanceOf(SQLException.class, thrownException); + assertNotNull(thrownException.getCause()); + assertSame(thrownException.getCause(), exception); + + assertDoesNotThrow(() -> wrapper.createStatement()); + } + } + + @Test + public void test_RaiseExceptionOnConnect() { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + + final SQLException exception = new SQLException("test"); + ExceptionSimulatorManager.raiseExceptionOnNextConnect(exception); + + Throwable thrownException = assertThrows( + SQLException.class, + () -> new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)); + assertSame(exception, thrownException); + + assertDoesNotThrow( + () -> new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)); + } + + @Test + public void test_NoExceptionOnConnectWithCallback() { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + + ExceptionSimulatorManager.setCallback(mockConnectCallback); + + assertDoesNotThrow( + () -> new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)); + } + + @Test + public void test_RaiseExceptionOnConnectWithCallback() { + + final Properties props = new Properties(); + props.put(PropertyDefinition.PLUGINS.name, "dev"); + props.put(DialectManager.DIALECT.name, DialectCodes.PG); + + final SQLException exception = new SQLException("test"); + when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())) + .thenReturn(exception) + .thenReturn(null); + ExceptionSimulatorManager.setCallback(mockConnectCallback); + + Throwable thrownException = assertThrows( + SQLException.class, + () -> new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)); + assertSame(exception, thrownException); + + assertDoesNotThrow( + () -> new ConnectionWrapper( + servicesContainer, + props, + "any-protocol://any-host/", + mockConnectionProvider, + null, + mockTargetDriverDialect, + null)); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java index 8b8d8dbb5..63320e95d 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java @@ -1,344 +1,344 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.efm; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertNotNull; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.anyInt; -// import static org.mockito.ArgumentMatchers.anySet; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.any; -// import static org.mockito.Mockito.atMostOnce; -// import static org.mockito.Mockito.doThrow; -// import static org.mockito.Mockito.mock; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.HashSet; -// import java.util.Properties; -// import java.util.Set; -// import java.util.concurrent.locks.ReentrantLock; -// import java.util.function.Supplier; -// import java.util.stream.Stream; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.junit.jupiter.params.ParameterizedTest; -// import org.junit.jupiter.params.provider.Arguments; -// import org.junit.jupiter.params.provider.MethodSource; -// import org.mockito.ArgumentCaptor; -// import org.mockito.Captor; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.JdbcMethod; -// import software.amazon.jdbc.NodeChangeOptions; -// import software.amazon.jdbc.OldConnectionSuggestedAction; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.HostAvailability; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.Messages; -// import software.amazon.jdbc.util.RdsUrlType; -// import software.amazon.jdbc.util.RdsUtils; -// -// class HostMonitoringConnectionPluginTest { -// -// static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; -// static final String MONITOR_METHOD_NAME = JdbcMethod.STATEMENT_EXECUTEQUERY.methodName; -// static final String NO_MONITOR_METHOD_NAME = JdbcMethod.CONNECTION_ABORT.methodName; -// static final int FAILURE_DETECTION_TIME = 10; -// static final int FAILURE_DETECTION_INTERVAL = 100; -// static final int FAILURE_DETECTION_COUNT = 5; -// private static final Object[] EMPTY_ARGS = {}; -// @Mock PluginService pluginService; -// @Mock Dialect mockDialect; -// @Mock Connection connection; -// @Mock Statement statement; -// @Mock ResultSet resultSet; -// @Captor ArgumentCaptor stringArgumentCaptor; -// Properties properties = new Properties(); -// @Mock HostSpec hostSpec; -// @Mock HostSpec hostSpec2; -// @Mock Supplier supplier; -// @Mock RdsUtils rdsUtils; -// @Mock HostMonitorConnectionContext context; -// @Mock ReentrantLock mockReentrantLock; -// @Mock HostMonitorService monitorService; -// @Mock JdbcCallable sqlFunction; -// @Mock TargetDriverDialect targetDriverDialect; -// -// private HostMonitoringConnectionPlugin plugin; -// private AutoCloseable closeable; -// -// /** -// * Generate different sets of method arguments where one argument is null to ensure {@link -// * software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin#HostMonitoringConnectionPlugin(PluginService, -// * Properties)} can handle null arguments correctly. -// * -// * @return different sets of arguments. -// */ -// private static Stream generateNullArguments() { -// final PluginService pluginService = mock(PluginService.class); -// final Properties properties = new Properties(); -// -// return Stream.of( -// Arguments.of(null, null), -// Arguments.of(pluginService, null), -// Arguments.of(null, properties)); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @BeforeEach -// void init() throws Exception { -// closeable = MockitoAnnotations.openMocks(this); -// initDefaultMockReturns(); -// properties.clear(); -// } -// -// void initDefaultMockReturns() throws Exception { -// when(supplier.get()).thenReturn(monitorService); -// when(monitorService.startMonitoring( -// any(Connection.class), -// anySet(), -// any(HostSpec.class), -// any(Properties.class), -// anyInt(), -// anyInt(), -// anyInt())) -// .thenReturn(context); -// when(context.getLock()).thenReturn(mockReentrantLock); -// -// when(pluginService.getCurrentConnection()).thenReturn(connection); -// when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); -// when(pluginService.getDialect()).thenReturn(mockDialect); -// when(pluginService.getTargetDriverDialect()).thenReturn(targetDriverDialect); -// when(targetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn( -// new HashSet<>(Collections.singletonList(MONITOR_METHOD_NAME))); -// when(mockDialect.getHostAliasQuery()).thenReturn("any"); -// when(hostSpec.getHost()).thenReturn("host"); -// when(hostSpec.getHost()).thenReturn("port"); -// when(hostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); -// when(hostSpec2.getHost()).thenReturn("host"); -// when(hostSpec2.getHost()).thenReturn("port"); -// when(hostSpec2.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); -// when(connection.createStatement()).thenReturn(statement); -// when(statement.executeQuery(any())).thenReturn(resultSet); -// when(rdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); -// -// properties.put("failureDetectionEnabled", Boolean.TRUE.toString()); -// properties.put("failureDetectionTime", String.valueOf(FAILURE_DETECTION_TIME)); -// properties.put("failureDetectionInterval", String.valueOf(FAILURE_DETECTION_INTERVAL)); -// properties.put("failureDetectionCount", String.valueOf(FAILURE_DETECTION_COUNT)); -// } -// -// private void initializePlugin() { -// plugin = new HostMonitoringConnectionPlugin(pluginService, properties, supplier, rdsUtils); -// } -// -// @Test -// void test_executeWithMonitoringDisabled() throws Exception { -// properties.put("failureDetectionEnabled", Boolean.FALSE.toString()); -// -// initializePlugin(); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS); -// -// verify(supplier, never()).get(); -// verify(monitorService, never()) -// .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); -// verify(monitorService, never()).stopMonitoring(context); -// verify(sqlFunction, times(1)).call(); -// } -// -// @Test -// void test_executeWithNoNeedToMonitor() throws Exception { -// -// initializePlugin(); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// NO_MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS); -// -// verify(supplier, atMostOnce()).get(); -// verify(monitorService, never()) -// .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); -// verify(monitorService, never()).stopMonitoring(context); -// verify(sqlFunction, times(1)).call(); -// } -// -// @Test -// void test_executeMonitoringEnabled() throws Exception { -// -// initializePlugin(); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS); -// -// verify(supplier, times(1)).get(); -// verify(monitorService, times(1)) -// .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); -// verify(monitorService, times(1)).stopMonitoring(context); -// verify(sqlFunction, times(1)).call(); -// } -// -// /** -// * Tests exception being thrown in the finally block when checking connection status in the execute method. -// */ -// @Test -// void test_executeCleanUp_whenCheckingConnectionStatus_throwsException() throws SQLException { -// initializePlugin(); -// -// final SQLException expectedException = new SQLException("exception thrown during isClosed"); -// when(context.isNodeUnhealthy()).thenReturn(true); -// doThrow(expectedException).when(connection).isClosed(); -// final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS)); -// -// assertEquals(expectedException, actualException); -// } -// -// /** -// * Tests exception being thrown in the finally block -// * when an open connection object is detected for an unavailable node in the execute method. -// */ -// @Test -// void test_executeCleanUp_whenAbortConnection_throwsException() throws SQLException { -// initializePlugin(); -// -// final String errorMessage = Messages.get( -// "HostMonitoringConnectionPlugin.unavailableNode", -// new Object[] {"alias"}); -// -// when(hostSpec.asAlias()).thenReturn("alias"); -// when(connection.isClosed()).thenReturn(false); -// when(context.isNodeUnhealthy()).thenReturn(true); -// final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS)); -// -// assertEquals(errorMessage, actualException.getMessage()); -// verify(pluginService).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); -// verify(connection).close(); -// } -// -// @Test -// void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException { -// initializePlugin(); -// -// doThrow(new SQLException()).when(connection).createStatement(); -// -// // Ensure SQLException raised in `generateHostAliases` are ignored. -// final Connection conn = plugin.connect("protocol", hostSpec, properties, true, () -> connection); -// assertNotNull(conn); -// } -// -// @ParameterizedTest -// @MethodSource("nodeChangeOptions") -// void test_notifyConnectionChanged_nodeWentDown(final NodeChangeOptions option) throws SQLException { -// initializePlugin(); -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS); -// -// final Set aliases1 = new HashSet<>(Arrays.asList("alias1", "alias2")); -// final Set aliases2 = new HashSet<>(Arrays.asList("alias3", "alias4")); -// when(hostSpec.asAliases()).thenReturn(aliases1); -// when(hostSpec2.asAliases()).thenReturn(aliases2); -// when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); -// -// assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); -// // NodeKeys should contain {"alias1", "alias2"} -// verify(monitorService).stopMonitoringForAllConnections(aliases1); -// -// when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec2); -// assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); -// // NotifyConnectionChanged should reset the monitoringHostSpec. -// // NodeKeys should contain {"alias3", "alias4"} -// verify(monitorService).stopMonitoringForAllConnections(aliases2); -// } -// -// @Test -// void test_releaseResources() throws SQLException { -// initializePlugin(); -// -// // Test releaseResources when the monitor service has not been initialized. -// plugin.releaseResources(); -// verify(monitorService, never()).releaseResources(); -// -// // Test releaseResources when the monitor service has been initialized. -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// MONITOR_METHOD_INVOKE_ON, -// MONITOR_METHOD_NAME, -// sqlFunction, -// EMPTY_ARGS); -// plugin.releaseResources(); -// verify(monitorService).releaseResources(); -// } -// -// static Stream nodeChangeOptions() { -// return Stream.of( -// Arguments.of(NodeChangeOptions.WENT_DOWN), -// Arguments.of(NodeChangeOptions.NODE_DELETED) -// ); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.efm; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anySet; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atMostOnce; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.JdbcMethod; +import software.amazon.jdbc.NodeChangeOptions; +import software.amazon.jdbc.OldConnectionSuggestedAction; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RdsUrlType; +import software.amazon.jdbc.util.RdsUtils; + +class HostMonitoringConnectionPluginTest { + + static final Class MONITOR_METHOD_INVOKE_ON = Connection.class; + static final String MONITOR_METHOD_NAME = JdbcMethod.STATEMENT_EXECUTEQUERY.methodName; + static final String NO_MONITOR_METHOD_NAME = JdbcMethod.CONNECTION_ABORT.methodName; + static final int FAILURE_DETECTION_TIME = 10; + static final int FAILURE_DETECTION_INTERVAL = 100; + static final int FAILURE_DETECTION_COUNT = 5; + private static final Object[] EMPTY_ARGS = {}; + @Mock PluginService pluginService; + @Mock Dialect mockDialect; + @Mock Connection connection; + @Mock Statement statement; + @Mock ResultSet resultSet; + @Captor ArgumentCaptor stringArgumentCaptor; + Properties properties = new Properties(); + @Mock HostSpec hostSpec; + @Mock HostSpec hostSpec2; + @Mock Supplier supplier; + @Mock RdsUtils rdsUtils; + @Mock HostMonitorConnectionContext context; + @Mock ReentrantLock mockReentrantLock; + @Mock HostMonitorService monitorService; + @Mock JdbcCallable sqlFunction; + @Mock TargetDriverDialect targetDriverDialect; + + private HostMonitoringConnectionPlugin plugin; + private AutoCloseable closeable; + + /** + * Generate different sets of method arguments where one argument is null to ensure {@link + * software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin#HostMonitoringConnectionPlugin(PluginService, + * Properties)} can handle null arguments correctly. + * + * @return different sets of arguments. + */ + private static Stream generateNullArguments() { + final PluginService pluginService = mock(PluginService.class); + final Properties properties = new Properties(); + + return Stream.of( + Arguments.of(null, null), + Arguments.of(pluginService, null), + Arguments.of(null, properties)); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @BeforeEach + void init() throws Exception { + closeable = MockitoAnnotations.openMocks(this); + initDefaultMockReturns(); + properties.clear(); + } + + void initDefaultMockReturns() throws Exception { + when(supplier.get()).thenReturn(monitorService); + when(monitorService.startMonitoring( + any(Connection.class), + anySet(), + any(HostSpec.class), + any(Properties.class), + anyInt(), + anyInt(), + anyInt())) + .thenReturn(context); + when(context.getLock()).thenReturn(mockReentrantLock); + + when(pluginService.getCurrentConnection()).thenReturn(connection); + when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); + when(pluginService.getDialect()).thenReturn(mockDialect); + when(pluginService.getTargetDriverDialect()).thenReturn(targetDriverDialect); + when(targetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn( + new HashSet<>(Collections.singletonList(MONITOR_METHOD_NAME))); + when(mockDialect.getHostAliasQuery()).thenReturn("any"); + when(hostSpec.getHost()).thenReturn("host"); + when(hostSpec.getHost()).thenReturn("port"); + when(hostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); + when(hostSpec2.getHost()).thenReturn("host"); + when(hostSpec2.getHost()).thenReturn("port"); + when(hostSpec2.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); + when(connection.createStatement()).thenReturn(statement); + when(statement.executeQuery(any())).thenReturn(resultSet); + when(rdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); + + properties.put("failureDetectionEnabled", Boolean.TRUE.toString()); + properties.put("failureDetectionTime", String.valueOf(FAILURE_DETECTION_TIME)); + properties.put("failureDetectionInterval", String.valueOf(FAILURE_DETECTION_INTERVAL)); + properties.put("failureDetectionCount", String.valueOf(FAILURE_DETECTION_COUNT)); + } + + private void initializePlugin() { + plugin = new HostMonitoringConnectionPlugin(pluginService, properties, supplier, rdsUtils); + } + + @Test + void test_executeWithMonitoringDisabled() throws Exception { + properties.put("failureDetectionEnabled", Boolean.FALSE.toString()); + + initializePlugin(); + + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS); + + verify(supplier, never()).get(); + verify(monitorService, never()) + .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); + verify(monitorService, never()).stopMonitoring(context); + verify(sqlFunction, times(1)).call(); + } + + @Test + void test_executeWithNoNeedToMonitor() throws Exception { + + initializePlugin(); + + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + NO_MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS); + + verify(supplier, atMostOnce()).get(); + verify(monitorService, never()) + .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); + verify(monitorService, never()).stopMonitoring(context); + verify(sqlFunction, times(1)).call(); + } + + @Test + void test_executeMonitoringEnabled() throws Exception { + + initializePlugin(); + + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS); + + verify(supplier, times(1)).get(); + verify(monitorService, times(1)) + .startMonitoring(any(), any(), any(), any(), anyInt(), anyInt(), anyInt()); + verify(monitorService, times(1)).stopMonitoring(context); + verify(sqlFunction, times(1)).call(); + } + + /** + * Tests exception being thrown in the finally block when checking connection status in the execute method. + */ + @Test + void test_executeCleanUp_whenCheckingConnectionStatus_throwsException() throws SQLException { + initializePlugin(); + + final SQLException expectedException = new SQLException("exception thrown during isClosed"); + when(context.isNodeUnhealthy()).thenReturn(true); + doThrow(expectedException).when(connection).isClosed(); + final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS)); + + assertEquals(expectedException, actualException); + } + + /** + * Tests exception being thrown in the finally block + * when an open connection object is detected for an unavailable node in the execute method. + */ + @Test + void test_executeCleanUp_whenAbortConnection_throwsException() throws SQLException { + initializePlugin(); + + final String errorMessage = Messages.get( + "HostMonitoringConnectionPlugin.unavailableNode", + new Object[] {"alias"}); + + when(hostSpec.asAlias()).thenReturn("alias"); + when(connection.isClosed()).thenReturn(false); + when(context.isNodeUnhealthy()).thenReturn(true); + final SQLException actualException = assertThrows(SQLException.class, () -> plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS)); + + assertEquals(errorMessage, actualException.getMessage()); + verify(pluginService).setAvailability(any(), eq(HostAvailability.NOT_AVAILABLE)); + verify(connection).close(); + } + + @Test + void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException { + initializePlugin(); + + doThrow(new SQLException()).when(connection).createStatement(); + + // Ensure SQLException raised in `generateHostAliases` are ignored. + final Connection conn = plugin.connect("protocol", hostSpec, properties, true, () -> connection); + assertNotNull(conn); + } + + @ParameterizedTest + @MethodSource("nodeChangeOptions") + void test_notifyConnectionChanged_nodeWentDown(final NodeChangeOptions option) throws SQLException { + initializePlugin(); + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS); + + final Set aliases1 = new HashSet<>(Arrays.asList("alias1", "alias2")); + final Set aliases2 = new HashSet<>(Arrays.asList("alias3", "alias4")); + when(hostSpec.asAliases()).thenReturn(aliases1); + when(hostSpec2.asAliases()).thenReturn(aliases2); + when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); + + assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); + // NodeKeys should contain {"alias1", "alias2"} + verify(monitorService).stopMonitoringForAllConnections(aliases1); + + when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec2); + assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); + // NotifyConnectionChanged should reset the monitoringHostSpec. + // NodeKeys should contain {"alias3", "alias4"} + verify(monitorService).stopMonitoringForAllConnections(aliases2); + } + + @Test + void test_releaseResources() throws SQLException { + initializePlugin(); + + // Test releaseResources when the monitor service has not been initialized. + plugin.releaseResources(); + verify(monitorService, never()).releaseResources(); + + // Test releaseResources when the monitor service has been initialized. + plugin.execute( + ResultSet.class, + SQLException.class, + MONITOR_METHOD_INVOKE_ON, + MONITOR_METHOD_NAME, + sqlFunction, + EMPTY_ARGS); + plugin.releaseResources(); + verify(monitorService).releaseResources(); + } + + static Stream nodeChangeOptions() { + return Stream.of( + Arguments.of(NodeChangeOptions.WENT_DOWN), + Arguments.of(NodeChangeOptions.NODE_DELETED) + ); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java index 91028405c..7b72dae6e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -1,225 +1,225 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.federatedauth; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyInt; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.time.Instant; -// import java.util.Properties; -// import java.util.concurrent.CompletableFuture; -// import java.util.concurrent.ExecutionException; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -// import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -// import software.amazon.awssdk.regions.Region; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.plugin.TokenInfo; -// import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; -// import software.amazon.jdbc.plugin.iam.IamTokenUtility; -// import software.amazon.jdbc.util.RdsUtils; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// -// class FederatedAuthPluginTest { -// -// private static final int DEFAULT_PORT = 1234; -// private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; -// private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; -// private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; -// private static final HostSpec HOST_SPEC = -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); -// private static final String DB_USER = "iamUser"; -// private static final String TEST_TOKEN = "someTestToken"; -// private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); -// @Mock private PluginService mockPluginService; -// @Mock private Dialect mockDialect; -// @Mock JdbcCallable mockLambda; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock private TelemetryContext mockTelemetryContext; -// @Mock private TelemetryCounter mockTelemetryCounter; -// @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; -// @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; -// @Mock private RdsUtils mockRdsUtils; -// @Mock private IamTokenUtility mockIamTokenUtils; -// @Mock private CompletableFuture completableFuture; -// @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; -// private Properties props; -// private AutoCloseable closeable; -// -// @BeforeEach -// public void init() throws ExecutionException, InterruptedException, SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// props = new Properties(); -// props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); -// props.setProperty(FederatedAuthPlugin.DB_USER.name, DB_USER); -// FederatedAuthPlugin.clearCache(); -// -// when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); -// when(mockIamTokenUtils.generateAuthenticationToken( -// any(AwsCredentialsProvider.class), -// any(Region.class), -// anyString(), -// anyInt(), -// anyString())).thenReturn(TEST_TOKEN); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); -// when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); -// when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) -// .thenReturn(mockAwsCredentialsProvider); -// when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); -// when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); -// } -// -// @AfterEach -// public void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @Test -// void testCachedToken() throws SQLException { -// FederatedAuthPlugin plugin = -// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); -// -// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; -// FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); -// -// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testExpiredCachedToken() throws SQLException { -// FederatedAuthPlugin spyPlugin = Mockito.spy( -// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); -// -// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; -// String someExpiredToken = "someExpiredToken"; -// TokenInfo expiredTokenInfo = new TokenInfo( -// someExpiredToken, Instant.now().minusMillis(300000)); -// FederatedAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); -// -// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, -// Region.US_EAST_2, -// HOST_SPEC.getHost(), -// DEFAULT_PORT, -// DB_USER); -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testNoCachedToken() throws SQLException { -// FederatedAuthPlugin spyPlugin = Mockito.spy( -// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); -// -// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// verify(mockIamTokenUtils).generateAuthenticationToken( -// mockAwsCredentialsProvider, -// Region.US_EAST_2, -// HOST_SPEC.getHost(), -// DEFAULT_PORT, -// DB_USER); -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testSpecifiedIamHostPortRegion() throws SQLException { -// final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; -// final int expectedPort = 9876; -// final Region expectedRegion = Region.US_WEST_2; -// -// props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost); -// props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); -// props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString()); -// -// final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; -// FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); -// -// FederatedAuthPlugin plugin = -// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testIdpCredentialsFallback() throws SQLException { -// String expectedUser = "expectedUser"; -// String expectedPassword = "expectedPassword"; -// PropertyDefinition.USER.set(props, expectedUser); -// PropertyDefinition.PASSWORD.set(props, expectedPassword); -// -// FederatedAuthPlugin plugin = -// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; -// FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); -// -// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props)); -// assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props)); -// } -// -// @Test -// public void testUsingIamHost() throws SQLException { -// IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); -// FederatedAuthPlugin spyPlugin = Mockito.spy( -// new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); -// -// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( -// mockAwsCredentialsProvider, -// Region.US_EAST_2, -// IAM_HOST, -// DEFAULT_PORT, -// DB_USER); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; +import software.amazon.jdbc.plugin.iam.IamTokenUtility; +import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class FederatedAuthPluginTest { + + private static final int DEFAULT_PORT = 1234; + private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; + private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; + private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; + private static final HostSpec HOST_SPEC = + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); + private static final String DB_USER = "iamUser"; + private static final String TEST_TOKEN = "someTestToken"; + private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); + @Mock private PluginService mockPluginService; + @Mock private Dialect mockDialect; + @Mock JdbcCallable mockLambda; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private TelemetryCounter mockTelemetryCounter; + @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; + @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; + @Mock private RdsUtils mockRdsUtils; + @Mock private IamTokenUtility mockIamTokenUtils; + @Mock private CompletableFuture completableFuture; + @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; + private Properties props; + private AutoCloseable closeable; + + @BeforeEach + public void init() throws ExecutionException, InterruptedException, SQLException { + closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); + props.setProperty(FederatedAuthPlugin.DB_USER.name, DB_USER); + FederatedAuthPlugin.clearCache(); + + when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); + when(mockIamTokenUtils.generateAuthenticationToken( + any(AwsCredentialsProvider.class), + any(Region.class), + anyString(), + anyInt(), + anyString())).thenReturn(TEST_TOKEN); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); + when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); + when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) + .thenReturn(mockAwsCredentialsProvider); + when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); + when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); + } + + @AfterEach + public void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void testCachedToken() throws SQLException { + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testExpiredCachedToken() throws SQLException { + FederatedAuthPlugin spyPlugin = Mockito.spy( + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + String someExpiredToken = "someExpiredToken"; + TokenInfo expiredTokenInfo = new TokenInfo( + someExpiredToken, Instant.now().minusMillis(300000)); + FederatedAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, + Region.US_EAST_2, + HOST_SPEC.getHost(), + DEFAULT_PORT, + DB_USER); + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testNoCachedToken() throws SQLException { + FederatedAuthPlugin spyPlugin = Mockito.spy( + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + verify(mockIamTokenUtils).generateAuthenticationToken( + mockAwsCredentialsProvider, + Region.US_EAST_2, + HOST_SPEC.getHost(), + DEFAULT_PORT, + DB_USER); + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testSpecifiedIamHostPortRegion() throws SQLException { + final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; + final int expectedPort = 9876; + final Region expectedRegion = Region.US_WEST_2; + + props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost); + props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); + props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString()); + + final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; + FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); + + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testIdpCredentialsFallback() throws SQLException { + String expectedUser = "expectedUser"; + String expectedPassword = "expectedPassword"; + PropertyDefinition.USER.set(props, expectedUser); + PropertyDefinition.PASSWORD.set(props, expectedPassword); + + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props)); + assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props)); + } + + @Test + public void testUsingIamHost() throws SQLException { + IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); + FederatedAuthPlugin spyPlugin = Mockito.spy( + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( + mockAwsCredentialsProvider, + Region.US_EAST_2, + IAM_HOST, + DEFAULT_PORT, + DB_USER); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java index 5e22770c2..910e06fe1 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java @@ -1,220 +1,220 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.federatedauth; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyInt; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.time.Instant; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -// import software.amazon.awssdk.regions.Region; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.plugin.TokenInfo; -// import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; -// import software.amazon.jdbc.plugin.iam.IamTokenUtility; -// import software.amazon.jdbc.util.RdsUtils; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// -// class OktaAuthPluginTest { -// -// private static final int DEFAULT_PORT = 1234; -// private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; -// -// private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; -// private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; -// private static final HostSpec HOST_SPEC = -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); -// private static final String DB_USER = "iamUser"; -// private static final String TEST_TOKEN = "someTestToken"; -// private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); -// @Mock private PluginService mockPluginService; -// @Mock private Dialect mockDialect; -// @Mock JdbcCallable mockLambda; -// @Mock private TelemetryFactory mockTelemetryFactory; -// @Mock private TelemetryContext mockTelemetryContext; -// @Mock private TelemetryCounter mockTelemetryCounter; -// @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; -// @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; -// @Mock private RdsUtils mockRdsUtils; -// @Mock private IamTokenUtility mockIamTokenUtils; -// -// private Properties props; -// private AutoCloseable closeable; -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// props = new Properties(); -// props.setProperty(PropertyDefinition.PLUGINS.name, "okta"); -// props.setProperty(OktaAuthPlugin.DB_USER.name, DB_USER); -// OktaAuthPlugin.clearCache(); -// -// when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); -// when(mockIamTokenUtils.generateAuthenticationToken( -// any(AwsCredentialsProvider.class), -// any(Region.class), -// anyString(), -// anyInt(), -// anyString())).thenReturn(TEST_TOKEN); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); -// when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); -// when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) -// .thenReturn(mockAwsCredentialsProvider); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// } -// -// @Test -// void testCachedToken() throws SQLException { -// final OktaAuthPlugin plugin = -// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; -// OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); -// -// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testExpiredCachedToken() throws SQLException { -// final OktaAuthPlugin spyPlugin = -// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; -// final String someExpiredToken = "someExpiredToken"; -// final TokenInfo expiredTokenInfo = new TokenInfo( -// someExpiredToken, Instant.now().minusMillis(300000)); -// OktaAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); -// -// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, -// Region.US_EAST_2, -// HOST_SPEC.getHost(), -// DEFAULT_PORT, -// DB_USER); -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testNoCachedToken() throws SQLException { -// final OktaAuthPlugin spyPlugin = -// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// verify(mockIamTokenUtils).generateAuthenticationToken( -// mockAwsCredentialsProvider, -// Region.US_EAST_2, -// HOST_SPEC.getHost(), -// DEFAULT_PORT, -// DB_USER); -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testSpecifiedIamHostPortRegion() throws SQLException { -// final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; -// final int expectedPort = 9876; -// final Region expectedRegion = Region.US_WEST_2; -// -// props.setProperty(OktaAuthPlugin.IAM_HOST.name, expectedHost); -// props.setProperty(OktaAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); -// props.setProperty(OktaAuthPlugin.IAM_REGION.name, expectedRegion.toString()); -// -// final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; -// OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); -// -// OktaAuthPlugin plugin = -// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// @Test -// void testIdpCredentialsFallback() throws SQLException { -// final String expectedUser = "expectedUser"; -// final String expectedPassword = "expectedPassword"; -// PropertyDefinition.USER.set(props, expectedUser); -// PropertyDefinition.PASSWORD.set(props, expectedPassword); -// -// final OktaAuthPlugin plugin = -// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); -// -// final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; -// OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); -// -// plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// assertEquals(expectedUser, OktaAuthPlugin.IDP_USERNAME.getString(props)); -// assertEquals(expectedPassword, OktaAuthPlugin.IDP_PASSWORD.getString(props)); -// } -// -// @Test -// public void testUsingIamHost() throws SQLException { -// IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); -// OktaAuthPlugin spyPlugin = Mockito.spy( -// new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); -// -// spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); -// -// assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( -// mockAwsCredentialsProvider, -// Region.US_EAST_2, -// IAM_HOST, -// DEFAULT_PORT, -// DB_USER); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Instant; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; +import software.amazon.jdbc.plugin.iam.IamTokenUtility; +import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class OktaAuthPluginTest { + + private static final int DEFAULT_PORT = 1234; + private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; + + private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; + private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; + private static final HostSpec HOST_SPEC = + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build(); + private static final String DB_USER = "iamUser"; + private static final String TEST_TOKEN = "someTestToken"; + private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); + @Mock private PluginService mockPluginService; + @Mock private Dialect mockDialect; + @Mock JdbcCallable mockLambda; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private TelemetryCounter mockTelemetryCounter; + @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; + @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; + @Mock private RdsUtils mockRdsUtils; + @Mock private IamTokenUtility mockIamTokenUtils; + + private Properties props; + private AutoCloseable closeable; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "okta"); + props.setProperty(OktaAuthPlugin.DB_USER.name, DB_USER); + OktaAuthPlugin.clearCache(); + + when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); + when(mockIamTokenUtils.generateAuthenticationToken( + any(AwsCredentialsProvider.class), + any(Region.class), + anyString(), + anyInt(), + anyString())).thenReturn(TEST_TOKEN); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); + when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); + when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) + .thenReturn(mockAwsCredentialsProvider); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + void testCachedToken() throws SQLException { + final OktaAuthPlugin plugin = + new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testExpiredCachedToken() throws SQLException { + final OktaAuthPlugin spyPlugin = + new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + final String someExpiredToken = "someExpiredToken"; + final TokenInfo expiredTokenInfo = new TokenInfo( + someExpiredToken, Instant.now().minusMillis(300000)); + OktaAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, + Region.US_EAST_2, + HOST_SPEC.getHost(), + DEFAULT_PORT, + DB_USER); + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testNoCachedToken() throws SQLException { + final OktaAuthPlugin spyPlugin = + new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + verify(mockIamTokenUtils).generateAuthenticationToken( + mockAwsCredentialsProvider, + Region.US_EAST_2, + HOST_SPEC.getHost(), + DEFAULT_PORT, + DB_USER); + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testSpecifiedIamHostPortRegion() throws SQLException { + final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; + final int expectedPort = 9876; + final Region expectedRegion = Region.US_WEST_2; + + props.setProperty(OktaAuthPlugin.IAM_HOST.name, expectedHost); + props.setProperty(OktaAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); + props.setProperty(OktaAuthPlugin.IAM_REGION.name, expectedRegion.toString()); + + final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + expectedPort + ":iamUser"; + OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); + + OktaAuthPlugin plugin = + new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testIdpCredentialsFallback() throws SQLException { + final String expectedUser = "expectedUser"; + final String expectedPassword = "expectedPassword"; + PropertyDefinition.USER.set(props, expectedUser); + PropertyDefinition.PASSWORD.set(props, expectedPassword); + + final OktaAuthPlugin plugin = + new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); + + final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + assertEquals(expectedUser, OktaAuthPlugin.IDP_USERNAME.getString(props)); + assertEquals(expectedPassword, OktaAuthPlugin.IDP_PASSWORD.getString(props)); + } + + @Test + public void testUsingIamHost() throws SQLException { + IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST); + OktaAuthPlugin spyPlugin = Mockito.spy( + new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + verify(mockIamTokenUtils, times(1)).generateAuthenticationToken( + mockAwsCredentialsProvider, + Region.US_EAST_2, + IAM_HOST, + DEFAULT_PORT, + DB_USER); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java index f8153430a..a872ac96c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java @@ -1,295 +1,295 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.iam; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.anyInt; -// import static org.mockito.ArgumentMatchers.anyString; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doThrow; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import java.io.IOException; -// import java.net.HttpURLConnection; -// import java.net.URL; -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.time.Instant; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeAll; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.Mockito; -// import org.mockito.MockitoAnnotations; -// import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -// import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -// import software.amazon.awssdk.regions.Region; -// import software.amazon.jdbc.Driver; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.plugin.TokenInfo; -// import software.amazon.jdbc.util.RdsUtils; -// import software.amazon.jdbc.util.telemetry.TelemetryContext; -// import software.amazon.jdbc.util.telemetry.TelemetryCounter; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; -// -// class IamAuthConnectionPluginTest { -// -// private static final String GENERATED_TOKEN = "generatedToken"; -// private static final String TEST_TOKEN = "testToken"; -// private static final int DEFAULT_PG_PORT = 5432; -// private static final int DEFAULT_MYSQL_PORT = 3306; -// private static final String PG_CACHE_KEY = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" -// + DEFAULT_PG_PORT + ":postgresqlUser"; -// private static final String MYSQL_CACHE_KEY = "us-east-2:mysql.testdb.us-east-2.rds.amazonaws.com:" -// + DEFAULT_MYSQL_PORT + ":mysqlUser"; -// private static final String PG_DRIVER_PROTOCOL = "jdbc:postgresql:"; -// private static final String MYSQL_DRIVER_PROTOCOL = "jdbc:mysql:"; -// private static final HostSpec PG_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); -// private static final HostSpec PG_HOST_SPEC_WITH_PORT = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("pg.testdb.us-east-2.rds.amazonaws.com").port(1234).build(); -// private static final HostSpec PG_HOST_SPEC_WITH_REGION = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("pg.testdb.us-west-1.rds.amazonaws.com").build(); -// private static final HostSpec MYSQL_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("mysql.testdb.us-east-2.rds.amazonaws.com").build(); -// private Properties props; -// -// @Mock PluginService mockPluginService; -// @Mock TelemetryFactory mockTelemetryFactory; -// @Mock TelemetryCounter mockTelemetryCounter; -// @Mock TelemetryContext mockTelemetryContext; -// @Mock JdbcCallable mockLambda; -// @Mock Dialect mockDialect; -// @Mock private RdsUtils mockRdsUtils; -// @Mock private IamTokenUtility mockIamTokenUtils; -// private AutoCloseable closable; -// -// @BeforeEach -// public void init() { -// closable = MockitoAnnotations.openMocks(this); -// props = new Properties(); -// props.setProperty(PropertyDefinition.USER.name, "postgresqlUser"); -// props.setProperty(PropertyDefinition.PASSWORD.name, "postgresqlPassword"); -// props.setProperty(PropertyDefinition.PLUGINS.name, "iam"); -// IamAuthConnectionPlugin.clearCache(); -// -// when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); -// when(mockIamTokenUtils.generateAuthenticationToken( -// any(AwsCredentialsProvider.class), -// any(Region.class), -// anyString(), -// anyInt(), -// anyString())).thenReturn(GENERATED_TOKEN); -// when(mockPluginService.getDialect()).thenReturn(mockDialect); -// when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); -// when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); -// when(mockTelemetryFactory.openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED))).thenReturn( -// mockTelemetryContext); -// } -// -// @AfterEach -// public void cleanUp() throws Exception { -// closable.close(); -// } -// -// @BeforeAll -// public static void registerDrivers() throws SQLException { -// if (!org.postgresql.Driver.isRegistered()) { -// org.postgresql.Driver.register(); -// } -// -// if (!Driver.isRegistered()) { -// Driver.register(); -// } -// } -// -// @Test -// public void testPostgresConnectValidTokenInCache() throws SQLException { -// IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); -// -// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); -// } -// -// @Test -// public void testMySqlConnectValidTokenInCache() throws SQLException { -// props.setProperty(PropertyDefinition.USER.name, "mysqlUser"); -// props.setProperty(PropertyDefinition.PASSWORD.name, "mysqlPassword"); -// IamAuthCacheHolder.tokenCache.put(MYSQL_CACHE_KEY, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_MYSQL_PORT); -// -// testTokenSetInProps(MYSQL_DRIVER_PROTOCOL, MYSQL_HOST_SPEC); -// } -// -// @Test -// public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLException { -// final String invalidIamDefaultPort = "0"; -// props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); -// -// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" -// + PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser"; -// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); -// } -// -// @Test -// public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() throws SQLException { -// final String invalidIamDefaultPort = "0"; -// props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); -// -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); -// -// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" -// + DEFAULT_PG_PORT + ":postgresqlUser"; -// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); -// } -// -// @Test -// public void testConnectExpiredTokenInCache() throws SQLException { -// IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, -// new TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000))); -// -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); -// -// testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); -// } -// -// @Test -// public void testConnectEmptyCache() throws SQLException { -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); -// -// testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); -// } -// -// @Test -// public void testConnectWithSpecifiedPort() throws SQLException { -// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:" + "postgresqlUser"; -// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); -// } -// -// @Test -// public void testConnectWithSpecifiedIamDefaultPort() throws SQLException { -// final String iamDefaultPort = "9999"; -// props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, iamDefaultPort); -// final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" -// + iamDefaultPort + ":postgresqlUser"; -// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); -// } -// -// @Test -// public void testConnectWithSpecifiedRegion() throws SQLException { -// final String cacheKeyWithNewRegion = -// "us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":" + "postgresqlUser"; -// props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-west-1"); -// IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewRegion, -// new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); -// -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); -// -// testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_REGION); -// } -// -// @Test -// public void testConnectWithSpecifiedHost() throws SQLException { -// props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-east-2"); -// props.setProperty(IamAuthConnectionPlugin.IAM_HOST.name, "pg.testdb.us-east-2.rds.amazonaws.com"); -// -// when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); -// -// testGenerateToken( -// PG_DRIVER_PROTOCOL, -// new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("8.8.8.8").build(), -// "pg.testdb.us-east-2.rds.amazonaws.com"); -// } -// -// @Test -// public void testAwsSupportedRegionsUrlExists() throws IOException { -// final URL url = -// new URL("https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html"); -// final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); -// final int responseCode = urlConnection.getResponseCode(); -// -// assertEquals(HttpURLConnection.HTTP_OK, responseCode); -// } -// -// public void testTokenSetInProps(final String protocol, final HostSpec hostSpec) throws SQLException { -// -// IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); -// doThrow(new SQLException()).when(mockLambda).call(); -// -// assertThrows(SQLException.class, () -> targetPlugin.connect(protocol, hostSpec, props, true, mockLambda)); -// verify(mockLambda, times(1)).call(); -// -// assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// } -// -// private void testGenerateToken(final String protocol, final HostSpec hostSpec) throws SQLException { -// testGenerateToken(protocol, hostSpec, hostSpec.getHost()); -// } -// -// private void testGenerateToken( -// final String protocol, -// final HostSpec hostSpec, -// final String expectedHost) throws SQLException { -// final IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); -// final IamAuthConnectionPlugin spyPlugin = Mockito.spy(targetPlugin); -// -// doThrow(new SQLException()).when(mockLambda).call(); -// -// assertThrows(SQLException.class, -// () -> spyPlugin.connect(protocol, hostSpec, props, true, mockLambda)); -// -// verify(mockIamTokenUtils).generateAuthenticationToken( -// any(DefaultCredentialsProvider.class), -// eq(Region.US_EAST_2), -// eq(expectedHost), -// eq(DEFAULT_PG_PORT), -// eq("postgresqlUser")); -// verify(mockLambda, times(1)).call(); -// -// assertEquals(GENERATED_TOKEN, PropertyDefinition.PASSWORD.getString(props)); -// assertEquals(GENERATED_TOKEN, IamAuthCacheHolder.tokenCache.get(PG_CACHE_KEY).getToken()); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.URL; +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Instant; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.Driver; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +class IamAuthConnectionPluginTest { + + private static final String GENERATED_TOKEN = "generatedToken"; + private static final String TEST_TOKEN = "testToken"; + private static final int DEFAULT_PG_PORT = 5432; + private static final int DEFAULT_MYSQL_PORT = 3306; + private static final String PG_CACHE_KEY = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + DEFAULT_PG_PORT + ":postgresqlUser"; + private static final String MYSQL_CACHE_KEY = "us-east-2:mysql.testdb.us-east-2.rds.amazonaws.com:" + + DEFAULT_MYSQL_PORT + ":mysqlUser"; + private static final String PG_DRIVER_PROTOCOL = "jdbc:postgresql:"; + private static final String MYSQL_DRIVER_PROTOCOL = "jdbc:mysql:"; + private static final HostSpec PG_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); + private static final HostSpec PG_HOST_SPEC_WITH_PORT = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("pg.testdb.us-east-2.rds.amazonaws.com").port(1234).build(); + private static final HostSpec PG_HOST_SPEC_WITH_REGION = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("pg.testdb.us-west-1.rds.amazonaws.com").build(); + private static final HostSpec MYSQL_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("mysql.testdb.us-east-2.rds.amazonaws.com").build(); + private Properties props; + + @Mock PluginService mockPluginService; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TelemetryCounter mockTelemetryCounter; + @Mock TelemetryContext mockTelemetryContext; + @Mock JdbcCallable mockLambda; + @Mock Dialect mockDialect; + @Mock private RdsUtils mockRdsUtils; + @Mock private IamTokenUtility mockIamTokenUtils; + private AutoCloseable closable; + + @BeforeEach + public void init() { + closable = MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty(PropertyDefinition.USER.name, "postgresqlUser"); + props.setProperty(PropertyDefinition.PASSWORD.name, "postgresqlPassword"); + props.setProperty(PropertyDefinition.PLUGINS.name, "iam"); + IamAuthConnectionPlugin.clearCache(); + + when(mockRdsUtils.getRdsRegion(anyString())).thenReturn("us-east-2"); + when(mockIamTokenUtils.generateAuthenticationToken( + any(AwsCredentialsProvider.class), + any(Region.class), + anyString(), + anyInt(), + anyString())).thenReturn(GENERATED_TOKEN); + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); + when(mockTelemetryFactory.openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED))).thenReturn( + mockTelemetryContext); + } + + @AfterEach + public void cleanUp() throws Exception { + closable.close(); + } + + @BeforeAll + public static void registerDrivers() throws SQLException { + if (!org.postgresql.Driver.isRegistered()) { + org.postgresql.Driver.register(); + } + + if (!Driver.isRegistered()) { + Driver.register(); + } + } + + @Test + public void testPostgresConnectValidTokenInCache() throws SQLException { + IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); + } + + @Test + public void testMySqlConnectValidTokenInCache() throws SQLException { + props.setProperty(PropertyDefinition.USER.name, "mysqlUser"); + props.setProperty(PropertyDefinition.PASSWORD.name, "mysqlPassword"); + IamAuthCacheHolder.tokenCache.put(MYSQL_CACHE_KEY, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_MYSQL_PORT); + + testTokenSetInProps(MYSQL_DRIVER_PROTOCOL, MYSQL_HOST_SPEC); + } + + @Test + public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLException { + final String invalidIamDefaultPort = "0"; + props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); + + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser"; + IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); + } + + @Test + public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() throws SQLException { + final String invalidIamDefaultPort = "0"; + props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + DEFAULT_PG_PORT + ":postgresqlUser"; + IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); + } + + @Test + public void testConnectExpiredTokenInCache() throws SQLException { + IamAuthCacheHolder.tokenCache.put(PG_CACHE_KEY, + new TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000))); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); + } + + @Test + public void testConnectEmptyCache() throws SQLException { + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + testGenerateToken(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); + } + + @Test + public void testConnectWithSpecifiedPort() throws SQLException { + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:" + "postgresqlUser"; + IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); + } + + @Test + public void testConnectWithSpecifiedIamDefaultPort() throws SQLException { + final String iamDefaultPort = "9999"; + props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, iamDefaultPort); + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + iamDefaultPort + ":postgresqlUser"; + IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewPort, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); + } + + @Test + public void testConnectWithSpecifiedRegion() throws SQLException { + final String cacheKeyWithNewRegion = + "us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":" + "postgresqlUser"; + props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-west-1"); + IamAuthCacheHolder.tokenCache.put(cacheKeyWithNewRegion, + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_REGION); + } + + @Test + public void testConnectWithSpecifiedHost() throws SQLException { + props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-east-2"); + props.setProperty(IamAuthConnectionPlugin.IAM_HOST.name, "pg.testdb.us-east-2.rds.amazonaws.com"); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + testGenerateToken( + PG_DRIVER_PROTOCOL, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("8.8.8.8").build(), + "pg.testdb.us-east-2.rds.amazonaws.com"); + } + + @Test + public void testAwsSupportedRegionsUrlExists() throws IOException { + final URL url = + new URL("https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html"); + final HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); + final int responseCode = urlConnection.getResponseCode(); + + assertEquals(HttpURLConnection.HTTP_OK, responseCode); + } + + public void testTokenSetInProps(final String protocol, final HostSpec hostSpec) throws SQLException { + + IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); + doThrow(new SQLException()).when(mockLambda).call(); + + assertThrows(SQLException.class, () -> targetPlugin.connect(protocol, hostSpec, props, true, mockLambda)); + verify(mockLambda, times(1)).call(); + + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + private void testGenerateToken(final String protocol, final HostSpec hostSpec) throws SQLException { + testGenerateToken(protocol, hostSpec, hostSpec.getHost()); + } + + private void testGenerateToken( + final String protocol, + final HostSpec hostSpec, + final String expectedHost) throws SQLException { + final IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); + final IamAuthConnectionPlugin spyPlugin = Mockito.spy(targetPlugin); + + doThrow(new SQLException()).when(mockLambda).call(); + + assertThrows(SQLException.class, + () -> spyPlugin.connect(protocol, hostSpec, props, true, mockLambda)); + + verify(mockIamTokenUtils).generateAuthenticationToken( + any(DefaultCredentialsProvider.class), + eq(Region.US_EAST_2), + eq(expectedHost), + eq(DEFAULT_PG_PORT), + eq("postgresqlUser")); + verify(mockLambda, times(1)).call(); + + assertEquals(GENERATED_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + assertEquals(GENERATED_TOKEN, IamAuthCacheHolder.tokenCache.get(PG_CACHE_KEY).getToken()); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java index 3aa1b40de..411233100 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java @@ -1,162 +1,162 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.limitless; -// -// import static org.junit.Assert.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.Mockito.doAnswer; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// import static software.amazon.jdbc.plugin.limitless.LimitlessConnectionPlugin.INTERVAL_MILLIS; -// -// import java.sql.Connection; -// import java.sql.SQLException; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import org.mockito.invocation.InvocationOnMock; -// import org.mockito.stubbing.Answer; -// import software.amazon.jdbc.HostListProvider; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.dialect.AuroraPgDialect; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.dialect.PgDialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// -// public class LimitlessConnectionPluginTest { -// -// private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; -// private static final HostSpec INPUT_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); -// private static final String CLUSTER_ID = "someClusterId"; -// -// private static final HostSpec expectedSelectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("expected-selected-instance").role(HostRole.WRITER).weight(Long.MAX_VALUE).build(); -// private static final Dialect supportedDialect = new AuroraPgDialect(); -// @Mock JdbcCallable mockConnectFuncLambda; -// @Mock private Connection mockConnection; -// @Mock private PluginService mockPluginService; -// @Mock private HostListProvider mockHostListProvider; -// @Mock private LimitlessRouterService mockLimitlessRouterService; -// private static Properties props; -// -// private static LimitlessConnectionPlugin plugin; -// -// private AutoCloseable closeable; -// -// @BeforeEach -// public void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// props = new Properties(); -// plugin = new LimitlessConnectionPlugin(mockPluginService, props, () -> mockLimitlessRouterService); -// -// when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); -// when(mockPluginService.getDialect()).thenReturn(supportedDialect); -// when(mockHostListProvider.getClusterId()).thenReturn(CLUSTER_ID); -// when(mockConnectFuncLambda.call()).thenReturn(mockConnection); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// } -// -// @Test -// void testConnect() throws SQLException { -// doAnswer(new Answer() { -// public Void answer(InvocationOnMock invocation) { -// LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; -// context.setConnection(mockConnection); -// return null; -// } -// }).when(mockLimitlessRouterService).establishConnection(any()); -// -// final Connection expectedConnection = mockConnection; -// final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, -// mockConnectFuncLambda); -// -// assertEquals(expectedConnection, actualConnection); -// verify(mockPluginService, times(1)).getDialect(); -// verify(mockConnectFuncLambda, times(0)).call(); -// verify(mockLimitlessRouterService, times(1)) -// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); -// verify(mockLimitlessRouterService, times(1)).establishConnection(any()); -// } -// -// @Test -// void testConnectGivenNullConnection() throws SQLException { -// doAnswer(new Answer() { -// public Void answer(InvocationOnMock invocation) { -// LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; -// context.setConnection(null); -// return null; -// } -// }).when(mockLimitlessRouterService).establishConnection(any()); -// -// assertThrows( -// SQLException.class, -// () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); -// -// verify(mockPluginService, times(1)).getDialect(); -// verify(mockConnectFuncLambda, times(0)).call(); -// verify(mockLimitlessRouterService, times(1)) -// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); -// verify(mockLimitlessRouterService, times(1)).establishConnection(any()); -// } -// -// @Test -// void testConnectGivenUnsupportedDialect() throws SQLException { -// final Dialect unsupportedDialect = new PgDialect(); -// when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, unsupportedDialect); -// -// assertThrows( -// UnsupportedOperationException.class, -// () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); -// -// verify(mockPluginService, times(2)).getDialect(); -// verify(mockConnectFuncLambda, times(1)).call(); -// verify(mockLimitlessRouterService, times(0)) -// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); -// verify(mockLimitlessRouterService, times(0)).establishConnection(any()); -// } -// -// @Test -// void testConnectGivenSupportedDialectAfterRefresh() throws SQLException { -// final Dialect unsupportedDialect = new PgDialect(); -// when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, supportedDialect); -// -// final Connection expectedConnection = mockConnection; -// final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, -// mockConnectFuncLambda); -// -// assertEquals(expectedConnection, actualConnection); -// verify(mockPluginService, times(2)).getDialect(); -// verify(mockConnectFuncLambda, times(1)).call(); -// verify(mockLimitlessRouterService, times(1)) -// .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); -// verify(mockLimitlessRouterService, times(1)).establishConnection(any()); -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.limitless; + +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.jdbc.plugin.limitless.LimitlessConnectionPlugin.INTERVAL_MILLIS; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import software.amazon.jdbc.HostListProvider; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.dialect.AuroraPgDialect; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.dialect.PgDialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; + +public class LimitlessConnectionPluginTest { + + private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; + private static final HostSpec INPUT_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); + private static final String CLUSTER_ID = "someClusterId"; + + private static final HostSpec expectedSelectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("expected-selected-instance").role(HostRole.WRITER).weight(Long.MAX_VALUE).build(); + private static final Dialect supportedDialect = new AuroraPgDialect(); + @Mock JdbcCallable mockConnectFuncLambda; + @Mock private Connection mockConnection; + @Mock private PluginService mockPluginService; + @Mock private HostListProvider mockHostListProvider; + @Mock private LimitlessRouterService mockLimitlessRouterService; + private static Properties props; + + private static LimitlessConnectionPlugin plugin; + + private AutoCloseable closeable; + + @BeforeEach + public void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); + plugin = new LimitlessConnectionPlugin(mockPluginService, props, () -> mockLimitlessRouterService); + + when(mockPluginService.getHostListProvider()).thenReturn(mockHostListProvider); + when(mockPluginService.getDialect()).thenReturn(supportedDialect); + when(mockHostListProvider.getClusterId()).thenReturn(CLUSTER_ID); + when(mockConnectFuncLambda.call()).thenReturn(mockConnection); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void testConnect() throws SQLException { + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocation) { + LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; + context.setConnection(mockConnection); + return null; + } + }).when(mockLimitlessRouterService).establishConnection(any()); + + final Connection expectedConnection = mockConnection; + final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, + mockConnectFuncLambda); + + assertEquals(expectedConnection, actualConnection); + verify(mockPluginService, times(1)).getDialect(); + verify(mockConnectFuncLambda, times(0)).call(); + verify(mockLimitlessRouterService, times(1)) + .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); + verify(mockLimitlessRouterService, times(1)).establishConnection(any()); + } + + @Test + void testConnectGivenNullConnection() throws SQLException { + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocation) { + LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; + context.setConnection(null); + return null; + } + }).when(mockLimitlessRouterService).establishConnection(any()); + + assertThrows( + SQLException.class, + () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); + + verify(mockPluginService, times(1)).getDialect(); + verify(mockConnectFuncLambda, times(0)).call(); + verify(mockLimitlessRouterService, times(1)) + .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); + verify(mockLimitlessRouterService, times(1)).establishConnection(any()); + } + + @Test + void testConnectGivenUnsupportedDialect() throws SQLException { + final Dialect unsupportedDialect = new PgDialect(); + when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, unsupportedDialect); + + assertThrows( + UnsupportedOperationException.class, + () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); + + verify(mockPluginService, times(2)).getDialect(); + verify(mockConnectFuncLambda, times(1)).call(); + verify(mockLimitlessRouterService, times(0)) + .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); + verify(mockLimitlessRouterService, times(0)).establishConnection(any()); + } + + @Test + void testConnectGivenSupportedDialectAfterRefresh() throws SQLException { + final Dialect unsupportedDialect = new PgDialect(); + when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, supportedDialect); + + final Connection expectedConnection = mockConnection; + final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, + mockConnectFuncLambda); + + assertEquals(expectedConnection, actualConnection); + verify(mockPluginService, times(2)).getDialect(); + verify(mockConnectFuncLambda, times(1)).call(); + verify(mockLimitlessRouterService, times(1)) + .startMonitoring(INPUT_HOST_SPEC, props, INTERVAL_MILLIS.getInteger(props)); + verify(mockLimitlessRouterService, times(1)).establishConnection(any()); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index 1dcb8e62c..c7c7bdc1b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -1,626 +1,626 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.plugin.readwritesplitting; -// -// import static org.junit.Assert.assertThrows; -// import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.mockito.AdditionalMatchers.not; -// import static org.mockito.ArgumentMatchers.any; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.never; -// import static org.mockito.Mockito.spy; -// import static org.mockito.Mockito.times; -// import static org.mockito.Mockito.verify; -// import static org.mockito.Mockito.when; -// -// import com.zaxxer.hikari.HikariConfig; -// import java.sql.Connection; -// import java.sql.ResultSet; -// import java.sql.SQLException; -// import java.sql.Statement; -// import java.util.Arrays; -// import java.util.Collections; -// import java.util.EnumSet; -// import java.util.List; -// import java.util.Properties; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.HostListProviderService; -// import software.amazon.jdbc.HostRole; -// import software.amazon.jdbc.HostSpec; -// import software.amazon.jdbc.HostSpecBuilder; -// import software.amazon.jdbc.JdbcCallable; -// import software.amazon.jdbc.NodeChangeOptions; -// import software.amazon.jdbc.OldConnectionSuggestedAction; -// import software.amazon.jdbc.PluginService; -// import software.amazon.jdbc.PropertyDefinition; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -// import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; -// import software.amazon.jdbc.util.SqlState; -// -// public class ReadWriteSplittingPluginTest { -// private static final String TEST_PROTOCOL = "jdbc:postgresql:"; -// private static final int TEST_PORT = 5432; -// private static final Properties defaultProps = new Properties(); -// -// private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-0").port(TEST_PORT).build(); -// private final HostSpec readerHostSpec1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-1").port(TEST_PORT).role(HostRole.READER).build(); -// private final HostSpec readerHostSpec2 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-2").port(TEST_PORT).role(HostRole.READER).build(); -// private final HostSpec readerHostSpec3 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-3").port(TEST_PORT).role(HostRole.READER).build(); -// private final HostSpec readerHostSpecWithIncorrectRole = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("instance-1").port(TEST_PORT).role(HostRole.WRITER).build(); -// private final HostSpec instanceUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("jdbc:aws-wrapper:postgresql://my-instance-name.XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT) -// .build(); -// private final HostSpec ipUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("10.10.10.10").port(TEST_PORT).build(); -// private final HostSpec clusterUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) -// .host("my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT).build(); -// private final List defaultHosts = Arrays.asList( -// writerHostSpec, -// readerHostSpec1, -// readerHostSpec2, -// readerHostSpec3); -// private final List singleReaderTopology = Arrays.asList( -// writerHostSpec, -// readerHostSpec1); -// -// private AutoCloseable closeable; -// -// @Mock private JdbcCallable mockConnectFunc; -// @Mock private JdbcCallable mockSqlFunction; -// @Mock private PluginService mockPluginService; -// @Mock private Dialect mockDialect; -// @Mock private HostListProviderService mockHostListProviderService; -// @Mock private Connection mockWriterConn; -// @Mock private Connection mockNewWriterConn; -// @Mock private Connection mockClosedWriterConn; -// @Mock private Connection mockReaderConn1; -// @Mock private Connection mockReaderConn2; -// @Mock private Connection mockReaderConn3; -// @Mock private Statement mockStatement; -// @Mock private ResultSet mockResultSet; -// @Mock private EnumSet mockChanges; -// -// @BeforeEach -// public void init() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// mockDefaultBehavior(); -// } -// -// @AfterEach -// void cleanUp() throws Exception { -// closeable.close(); -// defaultProps.clear(); -// } -// -// void mockDefaultBehavior() throws SQLException { -// when(this.mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); -// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); -// when(this.mockPluginService.getAllHosts()).thenReturn(defaultHosts); -// when(this.mockPluginService.getHosts()).thenReturn(defaultHosts); -// when(this.mockPluginService.getHostSpecByStrategy(eq(HostRole.READER), eq("random"))) -// .thenReturn(readerHostSpec1); -// when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class))) -// .thenReturn(mockWriterConn); -// when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class), any())) -// .thenReturn(mockWriterConn); -// when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(writerHostSpec); -// when(this.mockPluginService.getHostRole(mockWriterConn)).thenReturn(HostRole.WRITER); -// when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(HostRole.READER); -// when(this.mockPluginService.getHostRole(mockReaderConn2)).thenReturn(HostRole.READER); -// when(this.mockPluginService.getHostRole(mockReaderConn3)).thenReturn(HostRole.READER); -// when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class))) -// .thenReturn(mockReaderConn1); -// when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class), any())) -// .thenReturn(mockReaderConn1); -// when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class))) -// .thenReturn(mockReaderConn2); -// when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class), any())) -// .thenReturn(mockReaderConn2); -// when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class))) -// .thenReturn(mockReaderConn3); -// when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class), any())) -// .thenReturn(mockReaderConn3); -// when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); -// when(this.mockConnectFunc.call()).thenReturn(mockWriterConn); -// when(mockWriterConn.createStatement()).thenReturn(mockStatement); -// when(mockReaderConn1.createStatement()).thenReturn(mockStatement); -// when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); -// when(mockResultSet.next()).thenReturn(true); -// when(mockClosedWriterConn.isClosed()).thenReturn(true); -// } -// -// @Test -// public void testSetReadOnly_trueFalse() throws SQLException { -// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// plugin.switchConnectionIfRequired(true); -// -// verify(mockPluginService, times(1)) -// .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); -// verify(mockPluginService, times(0)) -// .setCurrentConnection(eq(mockWriterConn), any(HostSpec.class)); -// assertEquals(mockReaderConn1, plugin.getReaderConnection()); -// assertEquals(mockWriterConn, plugin.getWriterConnection()); -// -// when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); -// -// plugin.switchConnectionIfRequired(false); -// -// verify(mockPluginService, times(1)) -// .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); -// verify(mockPluginService, times(1)) -// .setCurrentConnection(eq(mockWriterConn), eq(writerHostSpec)); -// assertEquals(mockReaderConn1, plugin.getReaderConnection()); -// assertEquals(mockWriterConn, plugin.getWriterConnection()); -// } -// -// @Test -// public void testSetReadOnlyTrue_alreadyOnReader() throws SQLException { -// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// mockReaderConn1); -// -// plugin.switchConnectionIfRequired(true); -// -// verify(mockPluginService, times(0)) -// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); -// assertEquals(mockReaderConn1, plugin.getReaderConnection()); -// assertNull(plugin.getWriterConnection()); -// } -// -// @Test -// public void testSetReadOnlyFalse_alreadyOnWriter() throws SQLException { -// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// plugin.switchConnectionIfRequired(false); -// -// verify(mockPluginService, times(0)) -// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); -// assertEquals(mockWriterConn, plugin.getWriterConnection()); -// assertNull(plugin.getReaderConnection()); -// } -// -// @Test -// public void testSetReadOnly_falseInTransaction() { -// when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); -// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); -// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); -// when(mockPluginService.isInTransaction()).thenReturn(true); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// mockReaderConn1); -// -// final SQLException e = -// assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); -// assertEquals(SqlState.ACTIVE_SQL_TRANSACTION.getState(), e.getSQLState()); -// } -// -// @Test -// public void testSetReadOnly_true() throws SQLException { -// final ReadWriteSplittingPlugin plugin = -// new ReadWriteSplittingPlugin(mockPluginService, defaultProps); -// plugin.switchConnectionIfRequired(true); -// -// assertEquals(mockReaderConn1, plugin.getReaderConnection()); -// } -// -// @Test -// public void testSetReadOnly_false() throws SQLException { -// when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); -// when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// mockReaderConn1); -// plugin.switchConnectionIfRequired(false); -// -// assertEquals(mockWriterConn, plugin.getWriterConnection()); -// } -// -// @Test -// public void testSetReadOnly_true_oneHost() throws SQLException { -// when(this.mockPluginService.getHosts()).thenReturn(Collections.singletonList(writerHostSpec)); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// plugin.switchConnectionIfRequired(true); -// -// verify(mockPluginService, times(0)) -// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); -// assertEquals(mockWriterConn, plugin.getWriterConnection()); -// assertNull(plugin.getReaderConnection()); -// } -// -// @Test -// public void testSetReadOnly_false_writerConnectionFails() throws SQLException { -// when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps), any())) -// .thenThrow(SQLException.class); -// when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); -// when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// mockReaderConn1); -// -// final SQLException e = -// assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); -// assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), e.getSQLState()); -// verify(mockPluginService, times(0)) -// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); -// } -// -// @Test -// public void testSetReadOnly_true_readerConnectionFailed() throws SQLException { -// when(this.mockPluginService.connect(eq(readerHostSpec1), eq(defaultProps), any())) -// .thenThrow(SQLException.class); -// when(this.mockPluginService.connect(eq(readerHostSpec2), eq(defaultProps), any())) -// .thenThrow(SQLException.class); -// when(this.mockPluginService.connect(eq(readerHostSpec3), eq(defaultProps), any())) -// .thenThrow(SQLException.class); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// plugin.switchConnectionIfRequired(true); -// -// verify(mockPluginService, times(0)) -// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); -// assertNull(plugin.getReaderConnection()); -// } -// -// @Test -// public void testSetReadOnlyOnClosedConnection() throws SQLException { -// when(mockPluginService.getCurrentConnection()).thenReturn(mockClosedWriterConn); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockClosedWriterConn, -// null); -// -// final SQLException e = -// assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(true)); -// assertEquals(SqlState.CONNECTION_NOT_OPEN.getState(), e.getSQLState()); -// verify(mockPluginService, times(0)) -// .setCurrentConnection(any(Connection.class), any(HostSpec.class)); -// assertNull(plugin.getReaderConnection()); -// } -// -// @Test -// public void testExecute_failoverToNewWriter() throws SQLException { -// when(mockSqlFunction.call()).thenThrow(FailoverSuccessSQLException.class); -// when(mockPluginService.getCurrentConnection()).thenReturn(mockNewWriterConn); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// -// assertThrows( -// SQLException.class, -// () -> plugin.execute( -// ResultSet.class, -// SQLException.class, -// mockStatement, -// "Statement.executeQuery", -// mockSqlFunction, -// new Object[] { -// "begin"})); -// verify(mockWriterConn, times(1)).close(); -// } -// -// @Test -// public void testNotifyConnectionChange() { -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// -// final OldConnectionSuggestedAction suggestion = plugin.notifyConnectionChanged(mockChanges); -// -// assertEquals(mockWriterConn, plugin.getWriterConnection()); -// assertEquals(OldConnectionSuggestedAction.NO_OPINION, suggestion); -// } -// -// @Test -// public void testConnectNonInitialConnection() throws SQLException { -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// -// final Connection connection = -// plugin.connect(TEST_PROTOCOL, writerHostSpec, defaultProps, false, this.mockConnectFunc); -// -// assertEquals(mockWriterConn, connection); -// verify(mockConnectFunc).call(); -// verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); -// } -// -// @Test -// public void testConnectRdsInstanceUrl() throws SQLException { -// when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); -// when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// final Connection connection = plugin.connect( -// TEST_PROTOCOL, -// instanceUrlHostSpec, -// defaultProps, -// true, -// this.mockConnectFunc); -// -// assertEquals(mockReaderConn1, connection); -// verify(mockConnectFunc).call(); -// verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); -// } -// -// @Test -// public void testConnectReaderIpUrl() throws SQLException { -// when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); -// when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// final Connection connection = -// plugin.connect(TEST_PROTOCOL, ipUrlHostSpec, defaultProps, true, this.mockConnectFunc); -// -// assertEquals(mockReaderConn1, connection); -// verify(mockConnectFunc).call(); -// verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); -// } -// -// @Test -// public void testConnectClusterUrl() throws SQLException { -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// final Connection connection = -// plugin.connect(TEST_PROTOCOL, clusterUrlHostSpec, defaultProps, true, this.mockConnectFunc); -// -// assertEquals(mockWriterConn, connection); -// verify(mockConnectFunc).call(); -// verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); -// } -// -// @Test -// public void testConnect_errorUpdatingHostSpec() throws SQLException { -// when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); -// when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(null); -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// -// assertThrows( -// SQLException.class, -// () -> plugin.connect( -// TEST_PROTOCOL, -// ipUrlHostSpec, -// defaultProps, -// true, -// this.mockConnectFunc)); -// verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); -// } -// -// @Test -// public void testExecuteClearWarnings() throws SQLException { -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// mockReaderConn1); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// mockStatement, -// "Connection.clearWarnings", -// mockSqlFunction, -// new Object[] {} -// ); -// verify(mockWriterConn, times(1)).clearWarnings(); -// verify(mockReaderConn1, times(1)).clearWarnings(); -// } -// -// @Test -// public void testExecuteClearWarningsOnClosedConnectionsIsNotCalled() throws SQLException { -// when(mockWriterConn.isClosed()).thenReturn(true); -// when(mockReaderConn1.isClosed()).thenReturn(true); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// mockReaderConn1); -// -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// mockStatement, -// "Connection.clearWarnings", -// mockSqlFunction, -// new Object[] {} -// ); -// verify(mockWriterConn, never()).clearWarnings(); -// verify(mockReaderConn1, never()).clearWarnings(); -// } -// -// @Test -// public void testExecuteClearWarningsOnNullConnectionsIsNotCalled() throws SQLException { -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// -// // calling clearWarnings() on nullified connection would throw an exception -// assertDoesNotThrow(() -> { -// plugin.execute( -// ResultSet.class, -// SQLException.class, -// mockStatement, -// "Connection.clearWarnings", -// mockSqlFunction, -// new Object[] {} -// ); -// }); -// } -// -// @Test -// public void testClosePooledReaderConnectionAfterSetReadOnly() throws SQLException { -// doReturn(writerHostSpec) -// .doReturn(writerHostSpec) -// .doReturn(readerHostSpec1) -// .when(this.mockPluginService).getCurrentHostSpec(); -// doReturn(mockReaderConn1).when(mockPluginService).connect(readerHostSpec1, null); -// when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); -// when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// mockWriterConn, -// null); -// final ReadWriteSplittingPlugin spyPlugin = spy(plugin); -// -// spyPlugin.switchConnectionIfRequired(true); -// spyPlugin.switchConnectionIfRequired(false); -// -// verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockReaderConn1)); -// } -// -// @Test -// public void testClosePooledWriterConnectionAfterSetReadOnly() throws SQLException { -// doReturn(writerHostSpec) -// .doReturn(writerHostSpec) -// .doReturn(readerHostSpec1) -// .doReturn(readerHostSpec1) -// .doReturn(writerHostSpec) -// .when(this.mockPluginService).getCurrentHostSpec(); -// doReturn(mockWriterConn).when(mockPluginService).connect(writerHostSpec, null); -// when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); -// when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); -// -// final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( -// mockPluginService, -// defaultProps, -// mockHostListProviderService, -// null, -// null); -// final ReadWriteSplittingPlugin spyPlugin = spy(plugin); -// -// spyPlugin.switchConnectionIfRequired(true); -// spyPlugin.switchConnectionIfRequired(false); -// spyPlugin.switchConnectionIfRequired(true); -// -// verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockWriterConn)); -// } -// -// private static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props) { -// final HikariConfig config = new HikariConfig(); -// config.setMaximumPoolSize(3); -// config.setInitializationFailTimeout(75000); -// config.setConnectionTimeout(10000); -// return config; -// } -// -// private static String getPoolKey(HostSpec hostSpec, Properties props) { -// final String user = props.getProperty(PropertyDefinition.USER.name); -// final String somePropertyValue = props.getProperty("somePropertyValue"); -// return hostSpec.getUrl() + user + somePropertyValue; -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.readwritesplitting; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.AdditionalMatchers.not; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.zaxxer.hikari.HikariConfig; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Properties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.HostListProviderService; +import software.amazon.jdbc.HostRole; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.NodeChangeOptions; +import software.amazon.jdbc.OldConnectionSuggestedAction; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; +import software.amazon.jdbc.util.SqlState; + +public class ReadWriteSplittingPluginTest { + private static final String TEST_PROTOCOL = "jdbc:postgresql:"; + private static final int TEST_PORT = 5432; + private static final Properties defaultProps = new Properties(); + + private final HostSpec writerHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-0").port(TEST_PORT).build(); + private final HostSpec readerHostSpec1 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-1").port(TEST_PORT).role(HostRole.READER).build(); + private final HostSpec readerHostSpec2 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-2").port(TEST_PORT).role(HostRole.READER).build(); + private final HostSpec readerHostSpec3 = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-3").port(TEST_PORT).role(HostRole.READER).build(); + private final HostSpec readerHostSpecWithIncorrectRole = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("instance-1").port(TEST_PORT).role(HostRole.WRITER).build(); + private final HostSpec instanceUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("jdbc:aws-wrapper:postgresql://my-instance-name.XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT) + .build(); + private final HostSpec ipUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("10.10.10.10").port(TEST_PORT).build(); + private final HostSpec clusterUrlHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com").port(TEST_PORT).build(); + private final List defaultHosts = Arrays.asList( + writerHostSpec, + readerHostSpec1, + readerHostSpec2, + readerHostSpec3); + private final List singleReaderTopology = Arrays.asList( + writerHostSpec, + readerHostSpec1); + + private AutoCloseable closeable; + + @Mock private JdbcCallable mockConnectFunc; + @Mock private JdbcCallable mockSqlFunction; + @Mock private PluginService mockPluginService; + @Mock private Dialect mockDialect; + @Mock private HostListProviderService mockHostListProviderService; + @Mock private Connection mockWriterConn; + @Mock private Connection mockNewWriterConn; + @Mock private Connection mockClosedWriterConn; + @Mock private Connection mockReaderConn1; + @Mock private Connection mockReaderConn2; + @Mock private Connection mockReaderConn3; + @Mock private Statement mockStatement; + @Mock private ResultSet mockResultSet; + @Mock private EnumSet mockChanges; + + @BeforeEach + public void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + mockDefaultBehavior(); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + defaultProps.clear(); + } + + void mockDefaultBehavior() throws SQLException { + when(this.mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); + when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); + when(this.mockPluginService.getAllHosts()).thenReturn(defaultHosts); + when(this.mockPluginService.getHosts()).thenReturn(defaultHosts); + when(this.mockPluginService.getHostSpecByStrategy(eq(HostRole.READER), eq("random"))) + .thenReturn(readerHostSpec1); + when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class))) + .thenReturn(mockWriterConn); + when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class), any())) + .thenReturn(mockWriterConn); + when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(writerHostSpec); + when(this.mockPluginService.getHostRole(mockWriterConn)).thenReturn(HostRole.WRITER); + when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(HostRole.READER); + when(this.mockPluginService.getHostRole(mockReaderConn2)).thenReturn(HostRole.READER); + when(this.mockPluginService.getHostRole(mockReaderConn3)).thenReturn(HostRole.READER); + when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class))) + .thenReturn(mockReaderConn1); + when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class), any())) + .thenReturn(mockReaderConn1); + when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class))) + .thenReturn(mockReaderConn2); + when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class), any())) + .thenReturn(mockReaderConn2); + when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class))) + .thenReturn(mockReaderConn3); + when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class), any())) + .thenReturn(mockReaderConn3); + when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); + when(this.mockConnectFunc.call()).thenReturn(mockWriterConn); + when(mockWriterConn.createStatement()).thenReturn(mockStatement); + when(mockReaderConn1.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + when(mockClosedWriterConn.isClosed()).thenReturn(true); + } + + @Test + public void testSetReadOnly_trueFalse() throws SQLException { + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + plugin.switchConnectionIfRequired(true); + + verify(mockPluginService, times(1)) + .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); + verify(mockPluginService, times(0)) + .setCurrentConnection(eq(mockWriterConn), any(HostSpec.class)); + assertEquals(mockReaderConn1, plugin.getReaderConnection()); + assertEquals(mockWriterConn, plugin.getWriterConnection()); + + when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); + when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); + + plugin.switchConnectionIfRequired(false); + + verify(mockPluginService, times(1)) + .setCurrentConnection(eq(mockReaderConn1), not(eq(writerHostSpec))); + verify(mockPluginService, times(1)) + .setCurrentConnection(eq(mockWriterConn), eq(writerHostSpec)); + assertEquals(mockReaderConn1, plugin.getReaderConnection()); + assertEquals(mockWriterConn, plugin.getWriterConnection()); + } + + @Test + public void testSetReadOnlyTrue_alreadyOnReader() throws SQLException { + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); + when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + mockReaderConn1); + + plugin.switchConnectionIfRequired(true); + + verify(mockPluginService, times(0)) + .setCurrentConnection(any(Connection.class), any(HostSpec.class)); + assertEquals(mockReaderConn1, plugin.getReaderConnection()); + assertNull(plugin.getWriterConnection()); + } + + @Test + public void testSetReadOnlyFalse_alreadyOnWriter() throws SQLException { + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); + when(mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + plugin.switchConnectionIfRequired(false); + + verify(mockPluginService, times(0)) + .setCurrentConnection(any(Connection.class), any(HostSpec.class)); + assertEquals(mockWriterConn, plugin.getWriterConnection()); + assertNull(plugin.getReaderConnection()); + } + + @Test + public void testSetReadOnly_falseInTransaction() { + when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); + when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.isInTransaction()).thenReturn(true); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + mockReaderConn1); + + final SQLException e = + assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); + assertEquals(SqlState.ACTIVE_SQL_TRANSACTION.getState(), e.getSQLState()); + } + + @Test + public void testSetReadOnly_true() throws SQLException { + final ReadWriteSplittingPlugin plugin = + new ReadWriteSplittingPlugin(mockPluginService, defaultProps); + plugin.switchConnectionIfRequired(true); + + assertEquals(mockReaderConn1, plugin.getReaderConnection()); + } + + @Test + public void testSetReadOnly_false() throws SQLException { + when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); + when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + mockReaderConn1); + plugin.switchConnectionIfRequired(false); + + assertEquals(mockWriterConn, plugin.getWriterConnection()); + } + + @Test + public void testSetReadOnly_true_oneHost() throws SQLException { + when(this.mockPluginService.getHosts()).thenReturn(Collections.singletonList(writerHostSpec)); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + plugin.switchConnectionIfRequired(true); + + verify(mockPluginService, times(0)) + .setCurrentConnection(any(Connection.class), any(HostSpec.class)); + assertEquals(mockWriterConn, plugin.getWriterConnection()); + assertNull(plugin.getReaderConnection()); + } + + @Test + public void testSetReadOnly_false_writerConnectionFails() throws SQLException { + when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps), any())) + .thenThrow(SQLException.class); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); + when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + mockReaderConn1); + + final SQLException e = + assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(false)); + assertEquals(SqlState.CONNECTION_UNABLE_TO_CONNECT.getState(), e.getSQLState()); + verify(mockPluginService, times(0)) + .setCurrentConnection(any(Connection.class), any(HostSpec.class)); + } + + @Test + public void testSetReadOnly_true_readerConnectionFailed() throws SQLException { + when(this.mockPluginService.connect(eq(readerHostSpec1), eq(defaultProps), any())) + .thenThrow(SQLException.class); + when(this.mockPluginService.connect(eq(readerHostSpec2), eq(defaultProps), any())) + .thenThrow(SQLException.class); + when(this.mockPluginService.connect(eq(readerHostSpec3), eq(defaultProps), any())) + .thenThrow(SQLException.class); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + plugin.switchConnectionIfRequired(true); + + verify(mockPluginService, times(0)) + .setCurrentConnection(any(Connection.class), any(HostSpec.class)); + assertNull(plugin.getReaderConnection()); + } + + @Test + public void testSetReadOnlyOnClosedConnection() throws SQLException { + when(mockPluginService.getCurrentConnection()).thenReturn(mockClosedWriterConn); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockClosedWriterConn, + null); + + final SQLException e = + assertThrows(SQLException.class, () -> plugin.switchConnectionIfRequired(true)); + assertEquals(SqlState.CONNECTION_NOT_OPEN.getState(), e.getSQLState()); + verify(mockPluginService, times(0)) + .setCurrentConnection(any(Connection.class), any(HostSpec.class)); + assertNull(plugin.getReaderConnection()); + } + + @Test + public void testExecute_failoverToNewWriter() throws SQLException { + when(mockSqlFunction.call()).thenThrow(FailoverSuccessSQLException.class); + when(mockPluginService.getCurrentConnection()).thenReturn(mockNewWriterConn); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + + assertThrows( + SQLException.class, + () -> plugin.execute( + ResultSet.class, + SQLException.class, + mockStatement, + "Statement.executeQuery", + mockSqlFunction, + new Object[] { + "begin"})); + verify(mockWriterConn, times(1)).close(); + } + + @Test + public void testNotifyConnectionChange() { + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + + final OldConnectionSuggestedAction suggestion = plugin.notifyConnectionChanged(mockChanges); + + assertEquals(mockWriterConn, plugin.getWriterConnection()); + assertEquals(OldConnectionSuggestedAction.NO_OPINION, suggestion); + } + + @Test + public void testConnectNonInitialConnection() throws SQLException { + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + + final Connection connection = + plugin.connect(TEST_PROTOCOL, writerHostSpec, defaultProps, false, this.mockConnectFunc); + + assertEquals(mockWriterConn, connection); + verify(mockConnectFunc).call(); + verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); + } + + @Test + public void testConnectRdsInstanceUrl() throws SQLException { + when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); + when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + final Connection connection = plugin.connect( + TEST_PROTOCOL, + instanceUrlHostSpec, + defaultProps, + true, + this.mockConnectFunc); + + assertEquals(mockReaderConn1, connection); + verify(mockConnectFunc).call(); + verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); + } + + @Test + public void testConnectReaderIpUrl() throws SQLException { + when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); + when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(readerHostSpecWithIncorrectRole); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + final Connection connection = + plugin.connect(TEST_PROTOCOL, ipUrlHostSpec, defaultProps, true, this.mockConnectFunc); + + assertEquals(mockReaderConn1, connection); + verify(mockConnectFunc).call(); + verify(mockHostListProviderService, times(1)).setInitialConnectionHostSpec(eq(readerHostSpec1)); + } + + @Test + public void testConnectClusterUrl() throws SQLException { + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + final Connection connection = + plugin.connect(TEST_PROTOCOL, clusterUrlHostSpec, defaultProps, true, this.mockConnectFunc); + + assertEquals(mockWriterConn, connection); + verify(mockConnectFunc).call(); + verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); + } + + @Test + public void testConnect_errorUpdatingHostSpec() throws SQLException { + when(this.mockConnectFunc.call()).thenReturn(mockReaderConn1); + when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(null); + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + + assertThrows( + SQLException.class, + () -> plugin.connect( + TEST_PROTOCOL, + ipUrlHostSpec, + defaultProps, + true, + this.mockConnectFunc)); + verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); + } + + @Test + public void testExecuteClearWarnings() throws SQLException { + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + mockReaderConn1); + + plugin.execute( + ResultSet.class, + SQLException.class, + mockStatement, + "Connection.clearWarnings", + mockSqlFunction, + new Object[] {} + ); + verify(mockWriterConn, times(1)).clearWarnings(); + verify(mockReaderConn1, times(1)).clearWarnings(); + } + + @Test + public void testExecuteClearWarningsOnClosedConnectionsIsNotCalled() throws SQLException { + when(mockWriterConn.isClosed()).thenReturn(true); + when(mockReaderConn1.isClosed()).thenReturn(true); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + mockReaderConn1); + + plugin.execute( + ResultSet.class, + SQLException.class, + mockStatement, + "Connection.clearWarnings", + mockSqlFunction, + new Object[] {} + ); + verify(mockWriterConn, never()).clearWarnings(); + verify(mockReaderConn1, never()).clearWarnings(); + } + + @Test + public void testExecuteClearWarningsOnNullConnectionsIsNotCalled() throws SQLException { + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + + // calling clearWarnings() on nullified connection would throw an exception + assertDoesNotThrow(() -> { + plugin.execute( + ResultSet.class, + SQLException.class, + mockStatement, + "Connection.clearWarnings", + mockSqlFunction, + new Object[] {} + ); + }); + } + + @Test + public void testClosePooledReaderConnectionAfterSetReadOnly() throws SQLException { + doReturn(writerHostSpec) + .doReturn(writerHostSpec) + .doReturn(readerHostSpec1) + .when(this.mockPluginService).getCurrentHostSpec(); + doReturn(mockReaderConn1).when(mockPluginService).connect(readerHostSpec1, null); + when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); + when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + mockWriterConn, + null); + final ReadWriteSplittingPlugin spyPlugin = spy(plugin); + + spyPlugin.switchConnectionIfRequired(true); + spyPlugin.switchConnectionIfRequired(false); + + verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockReaderConn1)); + } + + @Test + public void testClosePooledWriterConnectionAfterSetReadOnly() throws SQLException { + doReturn(writerHostSpec) + .doReturn(writerHostSpec) + .doReturn(readerHostSpec1) + .doReturn(readerHostSpec1) + .doReturn(writerHostSpec) + .when(this.mockPluginService).getCurrentHostSpec(); + doReturn(mockWriterConn).when(mockPluginService).connect(writerHostSpec, null); + when(mockPluginService.getDriverProtocol()).thenReturn("jdbc:postgresql://"); + when(mockPluginService.isPooledConnectionProvider(any(), any())).thenReturn(true); + + final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( + mockPluginService, + defaultProps, + mockHostListProviderService, + null, + null); + final ReadWriteSplittingPlugin spyPlugin = spy(plugin); + + spyPlugin.switchConnectionIfRequired(true); + spyPlugin.switchConnectionIfRequired(false); + spyPlugin.switchConnectionIfRequired(true); + + verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockWriterConn)); + } + + private static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props) { + final HikariConfig config = new HikariConfig(); + config.setMaximumPoolSize(3); + config.setInitializationFailTimeout(75000); + config.setConnectionTimeout(10000); + return config; + } + + private static String getPoolKey(HostSpec hostSpec, Properties props) { + final String user = props.getProperty(PropertyDefinition.USER.name); + final String somePropertyValue = props.getProperty("somePropertyValue"); + return hostSpec.getUrl() + user + somePropertyValue; + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index a3a661c56..450b494a7 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -1,315 +1,315 @@ -// /* -// * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// * -// * Licensed under the Apache License, Version 2.0 (the "License"). -// * You may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -// package software.amazon.jdbc.util.monitoring; -// -// import static org.junit.jupiter.api.Assertions.assertEquals; -// import static org.junit.jupiter.api.Assertions.assertNotEquals; -// import static org.junit.jupiter.api.Assertions.assertNotNull; -// import static org.junit.jupiter.api.Assertions.assertNull; -// import static org.junit.jupiter.api.Assertions.assertThrows; -// import static org.mockito.ArgumentMatchers.anyInt; -// import static org.mockito.ArgumentMatchers.eq; -// import static org.mockito.Mockito.doNothing; -// import static org.mockito.Mockito.doReturn; -// import static org.mockito.Mockito.spy; -// -// import java.sql.SQLException; -// import java.util.Collections; -// import java.util.HashSet; -// import java.util.Properties; -// import java.util.concurrent.TimeUnit; -// import org.junit.jupiter.api.AfterEach; -// import org.junit.jupiter.api.BeforeEach; -// import org.junit.jupiter.api.Test; -// import org.mockito.Mock; -// import org.mockito.MockitoAnnotations; -// import software.amazon.jdbc.ConnectionProvider; -// import software.amazon.jdbc.dialect.Dialect; -// import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; -// import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -// import software.amazon.jdbc.util.FullServicesContainer; -// import software.amazon.jdbc.util.events.EventPublisher; -// import software.amazon.jdbc.util.storage.StorageService; -// import software.amazon.jdbc.util.telemetry.TelemetryFactory; -// -// class MonitorServiceImplTest { -// @Mock FullServicesContainer mockServicesContainer; -// @Mock StorageService mockStorageService; -// @Mock ConnectionProvider mockConnectionProvider; -// @Mock TelemetryFactory mockTelemetryFactory; -// @Mock TargetDriverDialect mockTargetDriverDialect; -// @Mock Dialect mockDbDialect; -// @Mock EventPublisher mockPublisher; -// String url = "jdbc:postgresql://somehost/somedb"; -// String protocol = "someProtocol"; -// Properties props = new Properties(); -// MonitorServiceImpl spyMonitorService; -// private AutoCloseable closeable; -// -// @BeforeEach -// void setUp() throws SQLException { -// closeable = MockitoAnnotations.openMocks(this); -// spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); -// doNothing().when(spyMonitorService).initCleanupThread(anyInt()); -// doReturn(mockServicesContainer).when(spyMonitorService).getNewServicesContainer( -// eq(mockStorageService), -// eq(mockConnectionProvider), -// eq(mockTelemetryFactory), -// eq(url), -// eq(protocol), -// eq(mockTargetDriverDialect), -// eq(mockDbDialect), -// eq(props)); -// } -// -// @AfterEach -// void tearDown() throws Exception { -// closeable.close(); -// spyMonitorService.releaseResources(); -// } -// -// @Test -// public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// TimeUnit.MINUTES.toNanos(1), -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// url, -// protocol, -// mockTargetDriverDialect, -// mockDbDialect, -// props, -// (mockServicesContainer) -> new NoOpMonitor(30) -// ); -// -// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(storedMonitor); -// assertEquals(monitor, storedMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// -// monitor.state.set(MonitorState.ERROR); -// spyMonitorService.checkMonitors(); -// -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// -// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(newMonitor); -// assertNotEquals(monitor, newMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, newMonitor.getState()); -// } -// -// @Test -// public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// 1, // heartbeat times out immediately -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// url, -// protocol, -// mockTargetDriverDialect, -// mockDbDialect, -// props, -// (mockServicesContainer) -> new NoOpMonitor(30) -// ); -// -// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(storedMonitor); -// assertEquals(monitor, storedMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// -// // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. -// spyMonitorService.checkMonitors(); -// -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// -// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(newMonitor); -// assertNotEquals(monitor, newMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, newMonitor.getState()); -// } -// -// @Test -// public void testMonitorExpired() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms -// TimeUnit.MINUTES.toNanos(1), -// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this -// // indicates it is not being used. -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// url, -// protocol, -// mockTargetDriverDialect, -// mockDbDialect, -// props, -// (mockServicesContainer) -> new NoOpMonitor(30) -// ); -// -// Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// assertNotNull(storedMonitor); -// assertEquals(monitor, storedMonitor); -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// -// // checkMonitors() should detect the expiration timeout and stop/remove the monitor. -// spyMonitorService.checkMonitors(); -// -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// -// Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); -// // monitor should have been removed when checkMonitors() was called. -// assertNull(newMonitor); -// } -// -// @Test -// public void testMonitorMismatch() { -// assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( -// CustomEndpointMonitorImpl.class, -// "testMonitor", -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// url, -// protocol, -// mockTargetDriverDialect, -// mockDbDialect, -// props, -// // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor -// // service should detect this and throw an exception. -// (mockServicesContainer) -> new NoOpMonitor(30) -// )); -// } -// -// @Test -// public void testRemove() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// TimeUnit.MINUTES.toNanos(1), -// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this -// // indicates it is not being used. -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// url, -// protocol, -// mockTargetDriverDialect, -// mockDbDialect, -// props, -// (mockServicesContainer) -> new NoOpMonitor(30) -// ); -// assertNotNull(monitor); -// -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); -// assertEquals(monitor, removedMonitor); -// assertEquals(MonitorState.RUNNING, monitor.getState()); -// } -// -// @Test -// public void testStopAndRemove() throws SQLException, InterruptedException { -// spyMonitorService.registerMonitorTypeIfAbsent( -// NoOpMonitor.class, -// TimeUnit.MINUTES.toNanos(1), -// TimeUnit.MINUTES.toNanos(1), -// // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this -// // indicates it is not being used. -// new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), -// null -// ); -// -// String key = "testMonitor"; -// NoOpMonitor monitor = spyMonitorService.runIfAbsent( -// NoOpMonitor.class, -// key, -// mockStorageService, -// mockTelemetryFactory, -// mockConnectionProvider, -// url, -// protocol, -// mockTargetDriverDialect, -// mockDbDialect, -// props, -// (mockServicesContainer) -> new NoOpMonitor(30) -// ); -// assertNotNull(monitor); -// -// // need to wait to give time for the monitor executor to start the monitor thread. -// TimeUnit.MILLISECONDS.sleep(250); -// spyMonitorService.stopAndRemove(NoOpMonitor.class, key); -// assertNull(spyMonitorService.get(NoOpMonitor.class, key)); -// assertEquals(MonitorState.STOPPED, monitor.getState()); -// } -// -// static class NoOpMonitor extends AbstractMonitor { -// protected NoOpMonitor(long terminationTimeoutSec) { -// super(terminationTimeoutSec); -// } -// -// @Override -// public void monitor() { -// // do nothing. -// } -// } -// } +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util.monitoring; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; + +import java.sql.SQLException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.events.EventPublisher; +import software.amazon.jdbc.util.storage.StorageService; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class MonitorServiceImplTest { + @Mock FullServicesContainer mockServicesContainer; + @Mock StorageService mockStorageService; + @Mock ConnectionProvider mockConnectionProvider; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock Dialect mockDbDialect; + @Mock EventPublisher mockPublisher; + String url = "jdbc:postgresql://somehost/somedb"; + String protocol = "someProtocol"; + Properties props = new Properties(); + MonitorServiceImpl spyMonitorService; + private AutoCloseable closeable; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + spyMonitorService = spy(new MonitorServiceImpl(mockPublisher)); + doNothing().when(spyMonitorService).initCleanupThread(anyInt()); + doReturn(mockServicesContainer).when(spyMonitorService).getNewServicesContainer( + eq(mockStorageService), + eq(mockConnectionProvider), + eq(mockTelemetryFactory), + eq(url), + eq(protocol), + eq(mockTargetDriverDialect), + eq(mockDbDialect), + eq(props)); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + spyMonitorService.releaseResources(); + } + + @Test + public void testMonitorError_monitorReCreated() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + TimeUnit.MINUTES.toNanos(1), + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + url, + protocol, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(storedMonitor); + assertEquals(monitor, storedMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, monitor.getState()); + + monitor.state.set(MonitorState.ERROR); + spyMonitorService.checkMonitors(); + + assertEquals(MonitorState.STOPPED, monitor.getState()); + + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(newMonitor); + assertNotEquals(monitor, newMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, newMonitor.getState()); + } + + @Test + public void testMonitorStuck_monitorReCreated() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + 1, // heartbeat times out immediately + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + url, + protocol, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(storedMonitor); + assertEquals(monitor, storedMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, monitor.getState()); + + // checkMonitors() should detect the heartbeat/inactivity timeout, stop the monitor, and re-create a new one. + spyMonitorService.checkMonitors(); + + assertEquals(MonitorState.STOPPED, monitor.getState()); + + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(newMonitor); + assertNotEquals(monitor, newMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, newMonitor.getState()); + } + + @Test + public void testMonitorExpired() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MILLISECONDS.toNanos(200), // monitor expires after 200ms + TimeUnit.MINUTES.toNanos(1), + // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this + // indicates it is not being used. + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + url, + protocol, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + + Monitor storedMonitor = spyMonitorService.get(NoOpMonitor.class, key); + assertNotNull(storedMonitor); + assertEquals(monitor, storedMonitor); + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + assertEquals(MonitorState.RUNNING, monitor.getState()); + + // checkMonitors() should detect the expiration timeout and stop/remove the monitor. + spyMonitorService.checkMonitors(); + + assertEquals(MonitorState.STOPPED, monitor.getState()); + + Monitor newMonitor = spyMonitorService.get(NoOpMonitor.class, key); + // monitor should have been removed when checkMonitors() was called. + assertNull(newMonitor); + } + + @Test + public void testMonitorMismatch() { + assertThrows(IllegalStateException.class, () -> spyMonitorService.runIfAbsent( + CustomEndpointMonitorImpl.class, + "testMonitor", + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + url, + protocol, + mockTargetDriverDialect, + mockDbDialect, + props, + // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor + // service should detect this and throw an exception. + (mockServicesContainer) -> new NoOpMonitor(30) + )); + } + + @Test + public void testRemove() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + TimeUnit.MINUTES.toNanos(1), + // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this + // indicates it is not being used. + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + url, + protocol, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + assertNotNull(monitor); + + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + Monitor removedMonitor = spyMonitorService.remove(NoOpMonitor.class, key); + assertEquals(monitor, removedMonitor); + assertEquals(MonitorState.RUNNING, monitor.getState()); + } + + @Test + public void testStopAndRemove() throws SQLException, InterruptedException { + spyMonitorService.registerMonitorTypeIfAbsent( + NoOpMonitor.class, + TimeUnit.MINUTES.toNanos(1), + TimeUnit.MINUTES.toNanos(1), + // even though we pass a re-create policy, we should not re-create it if the monitor is expired since this + // indicates it is not being used. + new HashSet<>(Collections.singletonList(MonitorErrorResponse.RECREATE)), + null + ); + + String key = "testMonitor"; + NoOpMonitor monitor = spyMonitorService.runIfAbsent( + NoOpMonitor.class, + key, + mockStorageService, + mockTelemetryFactory, + mockConnectionProvider, + url, + protocol, + mockTargetDriverDialect, + mockDbDialect, + props, + (mockServicesContainer) -> new NoOpMonitor(30) + ); + assertNotNull(monitor); + + // need to wait to give time for the monitor executor to start the monitor thread. + TimeUnit.MILLISECONDS.sleep(250); + spyMonitorService.stopAndRemove(NoOpMonitor.class, key); + assertNull(spyMonitorService.get(NoOpMonitor.class, key)); + assertEquals(MonitorState.STOPPED, monitor.getState()); + } + + static class NoOpMonitor extends AbstractMonitor { + protected NoOpMonitor(long terminationTimeoutSec) { + super(terminationTimeoutSec); + } + + @Override + public void monitor() { + // do nothing. + } + } +} From f0f3826c6d191fca2d036c71a71fd8a730734b92 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 24 Sep 2025 16:39:06 -0700 Subject: [PATCH 50/54] Tests compiling --- .../ConnectionPluginManagerBenchmarks.java | 22 +- .../jdbc/benchmarks/PluginBenchmarks.java | 10 +- .../amazon/jdbc/ConnectionPlugin.java | 4 +- .../aurora/TestAuroraHostListProvider.java | 6 +- .../jdbc/ConnectionPluginManagerTests.java | 67 +++-- .../amazon/jdbc/DialectDetectionTests.java | 42 +-- .../HikariPooledConnectionProviderTest.java | 44 ++- .../amazon/jdbc/PluginServiceImplTests.java | 275 ++---------------- .../RdsHostListProviderTest.java | 7 +- .../RdsMultiAzDbClusterListProviderTest.java | 7 +- .../amazon/jdbc/mock/TestPluginOne.java | 27 +- .../amazon/jdbc/mock/TestPluginThree.java | 4 +- .../jdbc/mock/TestPluginThrowException.java | 5 +- .../AuroraConnectionTrackerPluginTest.java | 16 +- ...AwsSecretsManagerConnectionPluginTest.java | 58 ++-- .../plugin/DefaultConnectionPluginTest.java | 14 +- .../CustomEndpointPluginTest.java | 8 +- .../dev/DeveloperConnectionPluginTest.java | 2 +- .../HostHostMonitorConnectionContextTest.java | 2 +- .../efm/HostHostMonitorServiceImplTest.java | 2 +- .../HostMonitoringConnectionPluginTest.java | 29 +- ...adedDefaultHostHostMonitorServiceTest.java | 4 +- .../FederatedAuthPluginTest.java | 15 +- .../federatedauth/OktaAuthPluginTest.java | 15 +- .../iam/IamAuthConnectionPluginTest.java | 9 +- .../LimitlessConnectionPluginTest.java | 37 +-- .../ReadWriteSplittingPluginTest.java | 35 +-- .../monitoring/MonitorServiceImplTest.java | 52 +--- 28 files changed, 259 insertions(+), 559 deletions(-) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java index 8f14ad5b3..260aa8e1d 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java @@ -63,13 +63,12 @@ import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.benchmarks.testplugin.BenchmarkPluginFactory; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.profile.ConfigurationProfile; import software.amazon.jdbc.profile.ConfigurationProfileBuilder; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; @@ -90,13 +89,14 @@ public class ConnectionPluginManagerBenchmarks { private static final String FIELD_SERVER_ID = "SERVER_ID"; private static final String FIELD_SESSION_ID = "SESSION_ID"; private static final String url = "protocol//url"; - private ConnectionContext pluginsContext; - private ConnectionContext noPluginsContext; + private ConnectionInfo pluginsContext; + private ConnectionInfo noPluginsContext; private ConnectionPluginManager pluginManager; private ConnectionPluginManager pluginManagerWithNoPlugins; @Mock ConnectionProvider mockConnectionProvider; @Mock ConnectionWrapper mockConnectionWrapper; + @Mock ConnectionInfo mockConnectionInfo; @Mock FullServicesContainer mockServicesContainer; @Mock PluginService mockPluginService; @Mock PluginManagerService mockPluginManagerService; @@ -153,12 +153,12 @@ public void setUpIteration() throws Exception { Properties noPluginsProps = new Properties(); noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); - this.noPluginsContext = new ConnectionContext(url, mockDriverDialect, noPluginsProps); + this.noPluginsContext = new ConnectionInfo(url, mockDriverDialect, noPluginsProps); Properties pluginsProps = new Properties(); pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); - this.pluginsContext = new ConnectionContext(url, mockDriverDialect, pluginsProps); + this.pluginsContext = new ConnectionInfo(url, mockDriverDialect, pluginsProps); TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); @@ -182,7 +182,7 @@ public void tearDownIteration() throws Exception { public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws SQLException { final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, mockConnectionWrapper, mockTelemetryFactory); - manager.init(mockServicesContainer, this.noPluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); + manager.init(mockServicesContainer, this.noPluginsContext.getProps(), mockPluginManagerService, configurationProfile); return manager; } @@ -190,16 +190,15 @@ public ConnectionPluginManager initConnectionPluginManagerWithNoPlugins() throws public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws SQLException { final ConnectionPluginManager manager = new ConnectionPluginManager(mockConnectionProvider, null, mockConnectionWrapper, mockTelemetryFactory); - manager.init(mockServicesContainer, this.pluginsContext.getPropsCopy(), mockPluginManagerService, configurationProfile); + manager.init(mockServicesContainer, this.pluginsContext.getProps(), mockPluginManagerService, configurationProfile); return manager; } @Benchmark public Connection connectWithPlugins() throws SQLException { return pluginManager.connect( - "driverProtocol", + mockConnectionInfo, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), - this.pluginsContext.getPropsCopy(), true, null); } @@ -207,9 +206,8 @@ public Connection connectWithPlugins() throws SQLException { @Benchmark public Connection connectWithNoPlugins() throws SQLException { return pluginManagerWithNoPlugins.connect( - "driverProtocol", + mockConnectionInfo, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), - this.noPluginsContext.getPropsCopy(), true, null); } diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java index 35932705c..6579f336e 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java @@ -63,6 +63,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.GaugeCallable; @@ -122,7 +123,7 @@ public static void main(String[] args) throws RunnerException { @Setup(Level.Iteration) public void setUpIteration() throws Exception { closeable = MockitoAnnotations.openMocks(this); - when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any())) + when(mockConnectionPluginManager.connect(any(), any(), anyBoolean(), any())) .thenReturn(mockConnection); when(mockConnectionPluginManager.execute( any(), any(), any(), eq(JdbcMethod.CONNECTION_CREATESTATEMENT), any(), any())) @@ -133,12 +134,7 @@ public void setUpIteration() throws Exception { when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); // noinspection unchecked when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - when(mockConnectionProvider.connect( - anyString(), - any(Dialect.class), - any(TargetDriverDialect.class), - any(HostSpec.class), - any(Properties.class))).thenReturn(mockConnection); + when(mockConnectionProvider.connect(any(ConnectionInfo.class), any(HostSpec.class))).thenReturn(mockConnection); when(mockConnection.createStatement()).thenReturn(mockStatement); when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); when(mockResultSet.next()).thenReturn(true, true, false); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java index 98ce8bdcb..f0b0c8aa8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java @@ -53,7 +53,7 @@ T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionInfo the connection info for the original connection + * @param connectionInfo the connection info for the original connection * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -80,7 +80,7 @@ Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionInfo the connection info for the original connection. + * @param connectionInfo the connection info for the original connection. * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has diff --git a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java index c35f6b0f8..c78ce3f16 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java +++ b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java @@ -16,15 +16,15 @@ package integration.container.aurora; -import java.util.Properties; import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class TestAuroraHostListProvider extends AuroraHostListProvider { public TestAuroraHostListProvider( - FullServicesContainer servicesContainer, Properties properties, String originalUrl) { - super(properties, originalUrl, servicesContainer, "", "", ""); + ConnectionInfo connectionInfo, FullServicesContainer servicesContainer) { + super(connectionInfo, servicesContainer, "", "", ""); } public static void clearCache() { diff --git a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java index 355276ad3..6a89a2f1b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java @@ -64,6 +64,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.wrapper.ConnectionWrapper; @@ -75,6 +76,7 @@ public class ConnectionPluginManagerTests { @Mock JdbcCallable mockSqlFunction; @Mock ConnectionProvider mockConnectionProvider; @Mock ConnectionWrapper mockConnectionWrapper; + @Mock ConnectionInfo mockConnectionInfo; @Mock TelemetryFactory mockTelemetryFactory; @Mock TelemetryContext mockTelemetryContext; @Mock FullServicesContainer mockServicesContainer; @@ -240,8 +242,8 @@ public void testConnect() throws Exception { new ConnectionPluginManager(mockConnectionProvider, null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - final Connection conn = target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + final Connection conn = target.connect(mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null); assertEquals(expectedConnection, conn); @@ -272,9 +274,11 @@ public void testConnectWithSkipPlugin() throws Exception { new ConnectionPluginManager(mockConnectionProvider, null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - final Connection conn = target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, - true, pluginOne); + final Connection conn = target.connect( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + true, + pluginOne); assertEquals(expectedConnection, conn); assertEquals(2, calls.size()); @@ -303,8 +307,9 @@ public void testForceConnect() throws Exception { new ConnectionPluginManager(mockConnectionProvider, null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - final Connection conn = target.forceConnect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + final Connection conn = target.forceConnect( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null); @@ -335,8 +340,11 @@ public void testConnectWithSQLExceptionBefore() { assertThrows( SQLException.class, - () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); + () -> target.connect( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + true, + null)); assertEquals(2, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -361,8 +369,11 @@ public void testConnectWithSQLExceptionAfter() { assertThrows( SQLException.class, - () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); + () -> target.connect( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + true, + null)); assertEquals(5, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -388,12 +399,13 @@ public void testConnectWithUnexpectedExceptionBefore() { new ConnectionPluginManager(mockConnectionProvider, null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - final Exception ex = - assertThrows( - IllegalArgumentException.class, - () -> target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); + assertThrows( + IllegalArgumentException.class, + () -> target.connect( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + true, + null)); assertEquals(2, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -416,12 +428,13 @@ public void testConnectWithUnexpectedExceptionAfter() { new ConnectionPluginManager(mockConnectionProvider, null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - final Exception ex = - assertThrows( - IllegalArgumentException.class, - () -> target.connect("any", - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true, null)); + assertThrows( + IllegalArgumentException.class, + () -> target.connect( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), + true, + null)); assertEquals(5, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -524,9 +537,8 @@ public void testForceConnectCachedJdbcCallForceConnect() throws Exception { null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); Object result = target.forceConnect( - "any", + mockConnectionInfo, testHostSpec, - testProperties, true, null); @@ -544,9 +556,8 @@ public void testForceConnectCachedJdbcCallForceConnect() throws Exception { calls.clear(); result = target.forceConnect( - "any", + mockConnectionInfo, testHostSpec, - testProperties, true, null); @@ -667,7 +678,7 @@ public void testOverridingDefaultPluginsWithPluginCodes() throws SQLException { } @Test - public void testTwoConnectionsDoNotBlockOneAnother() throws Exception { + public void testTwoConnectionsDoNotBlockOneAnother() { final Properties testProperties = new Properties(); final ArrayList testPlugins = new ArrayList<>(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java index 3b47f12bb..f15bcf6ad 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java @@ -52,6 +52,7 @@ import software.amazon.jdbc.exceptions.ExceptionManager; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.StorageService; public class DialectDetectionTests { @@ -61,18 +62,17 @@ public class DialectDetectionTests { private static final String MYSQL_PROTOCOL = "jdbc:mysql://"; private static final String PG_PROTOCOL = "jdbc:postgresql://"; private static final String MARIA_PROTOCOL = "jdbc:mariadb://"; - private final Properties props = new Properties(); private AutoCloseable closeable; @Mock private FullServicesContainer mockServicesContainer; @Mock private HostListProviderService mockHostListProviderService; @Mock private StorageService mockStorageService; + @Mock private TargetDriverDialect mockDriverDialect; @Mock private Connection mockConnection; @Mock private Statement mockStatement; @Mock private ResultSet mockSuccessResultSet; @Mock private ResultSet mockFailResultSet; @Mock private HostSpec mockHost; @Mock private ConnectionPluginManager mockPluginManager; - @Mock private TargetDriverDialect mockTargetDriverDialect; @Mock private ResultSetMetaData mockResultSetMetaData; @BeforeEach @@ -94,15 +94,17 @@ void cleanUp() throws Exception { } PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { + return getPluginService( + new ConnectionInfo(host + protocol, protocol, mockDriverDialect, new Properties())); + } + + PluginServiceImpl getPluginService(ConnectionInfo connectionInfo) throws SQLException { PluginServiceImpl pluginService = spy( new PluginServiceImpl( mockServicesContainer, new ExceptionManager(), - props, - protocol + host, - protocol, + connectionInfo, null, - mockTargetDriverDialect, null, null)); @@ -113,8 +115,10 @@ PluginServiceImpl getPluginService(String host, String protocol) throws SQLExcep @ParameterizedTest @MethodSource("getInitialDialectArguments") public void testInitialDialectDetection(String protocol, String host, Object expectedDialect) throws SQLException { + final ConnectionInfo connectionInfo = + new ConnectionInfo(protocol + host, protocol, mockDriverDialect, new Properties()); final DialectManager dialectManager = new DialectManager(this.getPluginService(host, protocol)); - final Dialect dialect = dialectManager.getDialect(protocol, host, new Properties()); + final Dialect dialect = dialectManager.getDialect(connectionInfo); assertEquals(expectedDialect, dialect.getClass()); } @@ -138,7 +142,7 @@ void testUpdateDialectMysqlUnchanged() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(MysqlDialect.class, target.dialect.getClass()); + assertEquals(MysqlDialect.class, target.getDialect().getClass()); } @Test @@ -154,7 +158,7 @@ void testUpdateDialectMysqlToRds() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); + assertEquals(RdsMysqlDialect.class, target.getDialect().getClass()); } @Test @@ -168,7 +172,7 @@ void testUpdateDialectMysqlToTaz() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, MYSQL_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); + assertEquals(AuroraMysqlDialect.class, target.getDialect().getClass()); } @Test @@ -180,7 +184,7 @@ void testUpdateDialectMysqlToAurora() throws SQLException { when(mockServicesContainer.getPluginService()).thenReturn(target); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); + assertEquals(AuroraMysqlDialect.class, target.getDialect().getClass()); } @Test @@ -189,7 +193,7 @@ void testUpdateDialectPgUnchanged() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(PgDialect.class, target.dialect.getClass()); + assertEquals(PgDialect.class, target.getDialect().getClass()); } @Test @@ -204,7 +208,7 @@ void testUpdateDialectPgToRds() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(RdsPgDialect.class, target.dialect.getClass()); + assertEquals(RdsPgDialect.class, target.getDialect().getClass()); } @Test @@ -219,7 +223,7 @@ void testUpdateDialectPgToTaz() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(RdsMultiAzDbClusterPgDialect.class, target.dialect.getClass()); + assertEquals(RdsMultiAzDbClusterPgDialect.class, target.getDialect().getClass()); } @Test @@ -234,7 +238,7 @@ void testUpdateDialectPgToAurora() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, PG_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(AuroraPgDialect.class, target.dialect.getClass()); + assertEquals(AuroraPgDialect.class, target.getDialect().getClass()); } @Test @@ -243,7 +247,7 @@ void testUpdateDialectMariaUnchanged() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(MariaDbDialect.class, target.dialect.getClass()); + assertEquals(MariaDbDialect.class, target.getDialect().getClass()); } @Test @@ -259,7 +263,7 @@ void testUpdateDialectMariaToMysqlRds() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(RdsMysqlDialect.class, target.dialect.getClass()); + assertEquals(RdsMysqlDialect.class, target.getDialect().getClass()); } @Test @@ -272,7 +276,7 @@ void testUpdateDialectMariaToMysqlTaz() throws SQLException { final PluginServiceImpl target = getPluginService(LOCALHOST, MARIA_PROTOCOL); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(RdsMultiAzDbClusterMysqlDialect.class, target.dialect.getClass()); + assertEquals(RdsMultiAzDbClusterMysqlDialect.class, target.getDialect().getClass()); } @Test @@ -284,6 +288,6 @@ void testUpdateDialectMariaToMysqlAurora() throws SQLException { when(mockServicesContainer.getPluginService()).thenReturn(target); target.setInitialConnectionHostSpec(mockHost); target.updateDialect(mockConnection); - assertEquals(AuroraMysqlDialect.class, target.dialect.getClass()); + assertEquals(AuroraMysqlDialect.class, target.getDialect().getClass()); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java index 6e5844ccf..cc266443b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java @@ -45,19 +45,19 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.ConnectInfo; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Pair; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.storage.SlidingExpirationCache; class HikariPooledConnectionProviderTest { @Mock Connection mockConnection; + @Mock ConnectionInfo mockConnectionInfo; @Mock HikariDataSource mockDataSource; @Mock HostSpec mockHostSpec; @Mock HikariConfig mockConfig; - @Mock Dialect mockDialect; @Mock TargetDriverDialect mockTargetDriverDialect; @Mock HikariDataSource dsWithNoConnections; @Mock HikariDataSource dsWith1Connection; @@ -129,24 +129,21 @@ void tearDown() throws Exception { void testConnectWithDefaultMapping() throws SQLException { when(mockHostSpec.getUrl()).thenReturn("url"); final Set expectedUrls = new HashSet<>(Collections.singletonList("url")); - final Set expectedKeys = new HashSet<>( + final Set> expectedKeys = new HashSet<>( Collections.singletonList(Pair.create("url", user1))); provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); - doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); + doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any()); doReturn(new ConnectInfo("url", new Properties())) .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); - Properties props = new Properties(); - props.setProperty(PropertyDefinition.USER.name, user1); - props.setProperty(PropertyDefinition.PASSWORD.name, password); - try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { + try (Connection conn = provider.connect(mockConnectionInfo, mockHostSpec)) { assertEquals(mockConnection, conn); assertEquals(1, provider.getHostCount()); final Set hosts = provider.getHosts(); assertEquals(expectedUrls, hosts); - final Set keys = provider.getKeys(); + final Set> keys = provider.getKeys(); assertEquals(expectedKeys, keys); } } @@ -154,22 +151,22 @@ void testConnectWithDefaultMapping() throws SQLException { @Test void testConnectWithCustomMapping() throws SQLException { when(mockHostSpec.getUrl()).thenReturn("url"); - final Set expectedKeys = new HashSet<>( + final Set> expectedKeys = new HashSet<>( Collections.singletonList(Pair.create("url", "url+someUniqueKey"))); provider = spy(new HikariPooledConnectionProvider( (hostSpec, properties) -> mockConfig, (hostSpec, properties) -> hostSpec.getUrl() + "+someUniqueKey")); - doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any(), any()); + doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any()); Properties props = new Properties(); props.setProperty(PropertyDefinition.USER.name, user1); props.setProperty(PropertyDefinition.PASSWORD.name, password); - try (Connection conn = provider.connect(protocol, mockDialect, mockTargetDriverDialect, mockHostSpec, props)) { + try (Connection conn = provider.connect(mockConnectionInfo, mockHostSpec)) { assertEquals(mockConnection, conn); assertEquals(1, provider.getHostCount()); - final Set keys = provider.getKeys(); + final Set> keys = provider.getKeys(); assertEquals(expectedKeys, keys); } } @@ -180,12 +177,13 @@ public void testAcceptsUrl() { provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); assertTrue( - provider.acceptsUrl(protocol, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build(), - defaultProps)); + provider.acceptsUrl( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build())); assertFalse( - provider.acceptsUrl(protocol, - new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build(), defaultProps)); + provider.acceptsUrl( + mockConnectionInfo, + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build())); } @Test @@ -208,8 +206,8 @@ public void testLeastConnectionsStrategy() throws SQLException { assertEquals(readerUrl1Connection, selectedHost.getHost()); } - private SlidingExpirationCache getTestPoolMap() { - SlidingExpirationCache map = new SlidingExpirationCache<>(); + private SlidingExpirationCache, AutoCloseable> getTestPoolMap() { + SlidingExpirationCache, AutoCloseable> map = new SlidingExpirationCache<>(); map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user1), (key) -> dsWith1Connection, TimeUnit.MINUTES.toNanos(10)); map.computeIfAbsent(Pair.create(readerHost2Connection.getUrl(), user2), @@ -227,7 +225,7 @@ public void testConfigurePool() throws SQLException { doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); - provider.configurePool(mockConfig, protocol, readerHost1Connection, defaultProps, mockTargetDriverDialect); + provider.configurePool(mockConfig, mockConnectionInfo, readerHost1Connection, defaultProps); verify(mockConfig).setJdbcUrl(expectedJdbcUrl); verify(mockConfig).setUsername(user1); verify(mockConfig).setPassword(password); @@ -238,10 +236,10 @@ public void testConnectToDeletedInstance() throws SQLException { provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); doReturn(mockDataSource).when(provider) - .createHikariDataSource(eq(protocol), eq(readerHost1Connection), eq(defaultProps), eq(mockTargetDriverDialect)); + .createHikariDataSource(eq(mockConnectionInfo), eq(readerHost1Connection), eq(defaultProps)); when(mockDataSource.getConnection()).thenThrow(SQLException.class); assertThrows(SQLException.class, - () -> provider.connect(protocol, mockDialect, mockTargetDriverDialect, readerHost1Connection, defaultProps)); + () -> provider.connect(mockConnectionInfo, readerHost1Connection)); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index 07a22941f..0e99b9eb4 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -42,7 +42,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; @@ -57,25 +56,17 @@ import org.mockito.MockitoAnnotations; import software.amazon.jdbc.dialect.AuroraPgDialect; import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.DialectManager; import software.amazon.jdbc.dialect.MysqlDialect; -import software.amazon.jdbc.exceptions.ExceptionManager; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.states.SessionStateService; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.storage.TestStorageServiceImpl; public class PluginServiceImplTests { - private static final Properties PROPERTIES = new Properties(); - private static final String URL = "url"; - private static final String DRIVER_PROTOCOL = "driverProtocol"; private StorageService storageService; private AutoCloseable closeable; @@ -85,12 +76,9 @@ public class PluginServiceImplTests { @Mock Connection newConnection; @Mock Connection oldConnection; @Mock HostListProvider hostListProvider; - @Mock DialectManager dialectManager; - @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock ConnectionInfo mockConnectionInfo; @Mock Statement statement; @Mock ResultSet resultSet; - ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); - @Mock SessionStateService sessionStateService; @Captor ArgumentCaptor> argumentChanges; @Captor ArgumentCaptor>> argumentChangesMap; @@ -121,16 +109,7 @@ public void testOldConnectionNoSuggestion() throws SQLException { .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") .build(); @@ -144,22 +123,17 @@ public void testOldConnectionNoSuggestion() throws SQLException { verify(oldConnection, times(1)).close(); } + protected PluginServiceImpl getPluginService() throws SQLException { + return new PluginServiceImpl(servicesContainer, mockConnectionInfo); + } + @Test public void testOldConnectionDisposeSuggestion() throws SQLException { when(pluginManager.notifyConnectionChanged(any(), any())) .thenReturn(EnumSet.of(OldConnectionSuggestedAction.DISPOSE)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") .build(); @@ -179,16 +153,7 @@ public void testOldConnectionPreserveSuggestion() throws SQLException { .thenReturn(EnumSet.of(OldConnectionSuggestedAction.PRESERVE)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") .build(); @@ -212,16 +177,7 @@ public void testOldConnectionMixedSuggestion() throws SQLException { OldConnectionSuggestedAction.DISPOSE)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("old-host") .build(); @@ -242,16 +198,7 @@ public void testChangesNewConnectionNewHostNewPortNewRoleNewAvailability() throw .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("old-host").port(1000).role(HostRole.WRITER).availability(HostAvailability.AVAILABLE).build(); @@ -281,16 +228,7 @@ public void testChangesNewConnectionNewRoleNewAvailability() throws SQLException .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) @@ -320,16 +258,7 @@ public void testChangesNewConnection() throws SQLException { .thenReturn(EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) @@ -359,16 +288,7 @@ public void testChangesNoChanges() throws SQLException { EnumSet.of(OldConnectionSuggestedAction.NO_OPINION)); PluginServiceImpl target = - spy(new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + spy(getPluginService()); target.currentConnection = oldConnection; target.currentHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("old-host").port(1000).role(HostRole.READER).availability(HostAvailability.AVAILABLE).build(); @@ -390,16 +310,7 @@ public void testSetNodeListAdded() throws SQLException { new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build())); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = new ArrayList<>(); target.hostListProvider = hostListProvider; @@ -424,16 +335,7 @@ public void testSetNodeListDeleted() throws SQLException { new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build())); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Arrays.asList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); @@ -461,16 +363,7 @@ public void testSetNodeListChanged() throws SQLException { .port(HostSpec.NO_PORT).role(HostRole.READER).build())); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); target.hostListProvider = hostListProvider; @@ -498,16 +391,7 @@ public void testSetNodeListNoChanges() throws SQLException { .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build())); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); target.hostListProvider = hostListProvider; @@ -524,16 +408,7 @@ public void testNodeAvailabilityNotChanged() throws SQLException { doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) @@ -553,16 +428,7 @@ public void testNodeAvailabilityChanged_WentDown() throws SQLException { doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) @@ -589,16 +455,7 @@ public void testNodeAvailabilityChanged_WentUp() throws SQLException { doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) @@ -636,16 +493,7 @@ public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { hostB.addAlias("hostB.custom.domain.com"); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Arrays.asList(hostA, hostB); @@ -671,7 +519,7 @@ public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQL final HostSpec hostA = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) - .build();; + .build(); hostA.addAlias("ip-10-10-10-10"); hostA.addAlias("hostA.custom.domain.com"); final HostSpec hostB = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) @@ -681,16 +529,7 @@ public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQL hostB.addAlias("hostB.custom.domain.com"); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.allHosts = Arrays.asList(hostA, hostB); @@ -759,16 +598,7 @@ void testRefreshHostList_withCachedHostAvailability() throws SQLException { when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); when(target.getHostListProvider()).thenReturn(hostListProvider); assertNotEquals(expectedHostSpecs, newHostSpecs); @@ -816,16 +646,7 @@ void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { when(hostListProvider.forceRefresh(newConnection)).thenReturn(newHostSpecs); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); when(target.getHostListProvider()).thenReturn(hostListProvider); assertNotEquals(expectedHostSpecs, newHostSpecs); @@ -841,16 +662,7 @@ void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { @Test void testIdentifyConnectionWithNoAliases() throws SQLException { PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); when(target.getHostListProvider()).thenReturn(hostListProvider); when(target.getDialect()).thenReturn(new MysqlDialect()); @@ -862,16 +674,7 @@ void testIdentifyConnectionWithAliases() throws SQLException { final HostSpec expected = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("test") .build(); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.hostListProvider = hostListProvider; when(target.getHostListProvider()).thenReturn(hostListProvider); when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); @@ -890,16 +693,7 @@ void testFillAliasesNonEmptyAliases() throws SQLException { oneAlias.addAlias(oneAlias.asAlias()); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); assertEquals(1, oneAlias.getAliases().size()); target.fillAliases(newConnection, oneAlias); @@ -912,16 +706,7 @@ void testFillAliasesNonEmptyAliases() throws SQLException { void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { final HostSpec empty = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo").build(); PluginServiceImpl target = spy( - new PluginServiceImpl( - servicesContainer, - new ExceptionManager(), - PROPERTIES, - URL, - DRIVER_PROTOCOL, - dialectManager, - mockTargetDriverDialect, - configurationProfile, - sessionStateService)); + getPluginService()); target.hostListProvider = hostListProvider; when(target.getDialect()).thenReturn(dialect); when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java index 797d151be..d064dfab2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java @@ -64,7 +64,9 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.storage.TestStorageServiceImpl; @@ -81,6 +83,7 @@ class RdsHostListProviderTest { @Mock private HostListProviderService mockHostListProviderService; @Mock private EventPublisher mockEventPublisher; @Mock Dialect mockTopologyAwareDialect; + @Mock TargetDriverDialect mockDriverDialect; @Captor private ArgumentCaptor queryCaptor; private AutoCloseable closeable; @@ -115,9 +118,9 @@ void tearDown() throws Exception { } private RdsHostListProvider getRdsHostListProvider(String originalUrl) throws SQLException { + ConnectionInfo connectionInfo = new ConnectionInfo(originalUrl, mockDriverDialect, new Properties()); RdsHostListProvider provider = new RdsHostListProvider( - new Properties(), - originalUrl, + connectionInfo, mockServicesContainer, "foo", "bar", "baz"); provider.init(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java index df6d6ee50..e699d0917 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java @@ -58,7 +58,9 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.storage.TestStorageServiceImpl; @@ -75,6 +77,7 @@ class RdsMultiAzDbClusterListProviderTest { @Mock private HostListProviderService mockHostListProviderService; @Mock private EventPublisher mockEventPublisher; @Mock Dialect mockTopologyAwareDialect; + @Mock TargetDriverDialect mockDriverDialect; @Captor private ArgumentCaptor queryCaptor; private AutoCloseable closeable; @@ -108,9 +111,9 @@ void tearDown() throws Exception { } private RdsMultiAzDbClusterListProvider getRdsMazDbClusterHostListProvider(String originalUrl) throws SQLException { + ConnectionInfo connectionInfo = new ConnectionInfo(originalUrl, mockDriverDialect, new Properties()); RdsMultiAzDbClusterListProvider provider = new RdsMultiAzDbClusterListProvider( - new Properties(), - originalUrl, + connectionInfo, mockServicesContainer, "foo", "bar", diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java index d92930cd9..d0093d090 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java @@ -19,12 +19,11 @@ import java.sql.Connection; import java.sql.SQLException; import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collections; import java.util.EnumSet; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Properties; import java.util.Set; import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.HostListProviderService; @@ -35,7 +34,7 @@ import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class TestPluginOne implements ConnectionPlugin { @@ -47,7 +46,7 @@ public class TestPluginOne implements ConnectionPlugin { public TestPluginOne(ArrayList calls) { this.calls = calls; - this.subscribedMethods = new HashSet<>(Arrays.asList("*")); + this.subscribedMethods = new HashSet<>(Collections.singletonList("*")); } @Override @@ -86,7 +85,7 @@ public T execute( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -99,12 +98,10 @@ public Connection connect( @Override public Connection forceConnect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable forceConnectFunc) - throws SQLException { + final ConnectionInfo connectionInfo, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable forceConnectFunc) throws SQLException { this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); Connection result = forceConnectFunc.call(); @@ -133,13 +130,9 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin @Override public void initHostProvider( - String driverProtocol, - String initialUrl, - Properties props, + ConnectionInfo connectionInfo, HostListProviderService hostListProviderService, - JdbcCallable initHostProviderFunc) - throws SQLException { - + JdbcCallable initHostProviderFunc) { // do nothing } diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java index 876444bc8..c26cc1dab 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java @@ -25,7 +25,7 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class TestPluginThree extends TestPluginOne { @@ -46,7 +46,7 @@ public TestPluginThree(ArrayList calls, Connection connection) { @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java index 91f7e6260..9412d3b23 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java @@ -21,10 +21,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; -import java.util.Properties; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.util.connection.ConnectionContext; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class TestPluginThrowException extends TestPluginOne { @@ -77,7 +76,7 @@ public T execute( @Override public Connection connect( - final ConnectionContext connectionContext, + final ConnectionInfo connectionInfo, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index 70d62ae9f..8af8f65d0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -54,6 +54,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class AuroraConnectionTrackerPluginTest { @@ -68,7 +69,7 @@ public class AuroraConnectionTrackerPluginTest { @Mock JdbcCallable mockConnectionFunction; @Mock JdbcCallable mockSqlFunction; @Mock JdbcCallable mockCloseOrAbortFunction; - @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock TargetDriverDialect mockDriverDialect; private static final Object[] SQL_ARGS = {"sql"}; @@ -86,8 +87,8 @@ void setUp() throws SQLException { when(mockRdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.getDialect()).thenReturn(mockTopologyAwareDialect); - when(mockPluginService.getTargetDriverDialect()).thenReturn(mockTargetDriverDialect); - when(mockTargetDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); + when(mockPluginService.getTargetDriverDialect()).thenReturn(mockDriverDialect); + when(mockDriverDialect.getNetworkBoundMethodNames(any())).thenReturn(new HashSet<>()); } @AfterEach @@ -111,10 +112,11 @@ public void testTrackNewInstanceConnections( mockRdsUtils, mockTracker); + final ConnectionInfo connectionInfo = + new ConnectionInfo(protocol + hostSpec.getHost(), mockDriverDialect, EMPTY_PROPERTIES); final Connection actualConnection = plugin.connect( - protocol, + connectionInfo, hostSpec, - EMPTY_PROPERTIES, isInitialConnection, mockConnectionFunction); @@ -131,10 +133,6 @@ public void testInvalidateOpenedConnectionsWhenWriterHostNotChange() throws SQLE .host("host") .role(HostRole.WRITER) .build(); - final HostSpec newHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("new-host") - .role(HostRole.WRITER) - .build(); // Host list changes during simulated failover when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList(originalHost)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java index 22a8339c1..aa88ed32b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java @@ -62,17 +62,14 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.dialect.DialectManager; import software.amazon.jdbc.exceptions.ExceptionHandler; -import software.amazon.jdbc.exceptions.ExceptionManager; import software.amazon.jdbc.exceptions.MySQLExceptionHandler; import software.amazon.jdbc.exceptions.PgExceptionHandler; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.ConfigurationProfileBuilder; -import software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Pair; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -121,10 +118,7 @@ public class AwsSecretsManagerConnectionPluginTest { @Mock TelemetryContext mockTelemetryContext; @Mock TelemetryCounter mockTelemetryCounter; @Mock TelemetryGauge mockTelemetryGauge; - @Mock TargetDriverDialect mockTargetDriverDialect; - ConfigurationProfile configurationProfile = ConfigurationProfileBuilder.get().withName("test").build(); - - @Mock SessionStateService mockSessionStateService; + @Mock TargetDriverDialect mockDriverDialect; @BeforeEach public void init() throws SQLException { @@ -133,7 +127,7 @@ public void init() throws SQLException { REGION_PROPERTY.set(TEST_PROPS, TEST_REGION); SECRET_ID_PROPERTY.set(TEST_PROPS, TEST_SECRET_ID); - when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) + when(mockDialectManager.getDialect(any(ConnectionInfo.class))) .thenReturn(mockTopologyAwareDialect); when(mockServicesContainer.getConnectionPluginManager()).thenReturn(mockConnectionPluginManager); @@ -151,7 +145,7 @@ public void init() throws SQLException { (host, r) -> mockSecretsManagerClient, (id) -> mockGetValueRequest); - when(mockDialectManager.getDialect(anyString(), anyString(), any(Properties.class))) + when(mockDialectManager.getDialect(any(ConnectionInfo.class))) .thenReturn(mockTopologyAwareDialect); when(mockService.getHostSpecBuilder()).thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); @@ -172,7 +166,7 @@ public void testConnectWithCachedSecrets() throws SQLException { // Add initial cached secret to be used for a connection. AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); - this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); + this.plugin.connect(getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc); assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); @@ -181,6 +175,10 @@ public void testConnectWithCachedSecrets() throws SQLException { assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); } + protected ConnectionInfo getConnectionInfo(String protocol) { + return new ConnectionInfo(protocol, mockDriverDialect, TEST_PROPS); + } + /** * The plugin will attempt to open a connection with an empty secret cache. The plugin will fetch the secret from the * AWS Secrets Manager. @@ -190,7 +188,7 @@ public void testConnectWithNewSecrets() throws SQLException { when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - this.plugin.connect(TEST_PG_PROTOCOL, TEST_HOSTSPEC, TEST_PROPS, true, this.connectFunc); + this.plugin.connect(getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc); assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); @@ -222,9 +220,8 @@ public void testFailedInitialConnectionWithUnhandledError() throws SQLException final SQLException connectionFailedException = assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_PG_PROTOCOL, + getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -265,9 +262,8 @@ public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_PG_PROTOCOL, + getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -279,16 +275,7 @@ public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( } private @NotNull PluginServiceImpl getPluginService(String protocol) throws SQLException { - return new PluginServiceImpl( - mockServicesContainer, - new ExceptionManager(), - TEST_PROPS, - "url", - protocol, - mockDialectManager, - mockTargetDriverDialect, - configurationProfile, - mockSessionStateService); + return new PluginServiceImpl(mockServicesContainer, getConnectionInfo(protocol)); } /** @@ -304,9 +291,8 @@ public void testFailedToReadSecrets() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_PG_PROTOCOL, + getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -331,9 +317,8 @@ public void testFailedToGetSecrets() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_PG_PROTOCOL, + getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -368,9 +353,8 @@ public void testFailedInitialConnectionWithWrappedGenericError(final String acce assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_PG_PROTOCOL, + getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -400,9 +384,8 @@ public void testConnectWithWrappedMySQLException() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_MYSQL_PROTOCOL, + getConnectionInfo(TEST_MYSQL_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -432,9 +415,8 @@ public void testConnectWithWrappedPostgreSQLException() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - TEST_PG_PROTOCOL, + getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, - TEST_PROPS, true, this.connectFunc)); @@ -453,7 +435,7 @@ public void testConnectViaARN(final String arn, final Region expectedRegionParse SECRET_ID_PROPERTY.set(props, arn); this.plugin = spy(new AwsSecretsManagerConnectionPlugin( - new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), + new PluginServiceImpl(mockServicesContainer, getConnectionInfo(TEST_PG_PROTOCOL)), props, (host, r) -> mockSecretsManagerClient, (id) -> mockGetValueRequest)); @@ -473,7 +455,7 @@ public void testConnectionWithRegionParameterAndARN(final String arn, final Regi REGION_PROPERTY.set(props, expectedRegion.toString()); this.plugin = spy(new AwsSecretsManagerConnectionPlugin( - new PluginServiceImpl(mockServicesContainer, props, "url", TEST_PG_PROTOCOL, mockTargetDriverDialect), + new PluginServiceImpl(mockServicesContainer, getConnectionInfo(TEST_PG_PROTOCOL)), props, (host, r) -> mockSecretsManagerClient, (id) -> mockGetValueRequest)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java index d8fac47be..070d7028f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java @@ -33,7 +33,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Properties; import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -49,6 +48,7 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -62,6 +62,7 @@ class DefaultConnectionPluginTest { @Mock PluginService pluginService; @Mock ConnectionProvider connectionProvider; @Mock PluginManagerService pluginManagerService; + @Mock ConnectionInfo mockConnectionInfo; @Mock JdbcCallable mockSqlFunction; @Mock JdbcCallable mockConnectFunction; @Mock Connection conn; @@ -86,11 +87,10 @@ void setUp() { when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); // noinspection unchecked when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any())) + when(mockConnectionProviderManager.getConnectionProvider(any(), any())) .thenReturn(connectionProvider); - plugin = new DefaultConnectionPlugin( - pluginService, connectionProvider, null, pluginManagerService, mockConnectionProviderManager); + plugin = new DefaultConnectionPlugin(pluginService, connectionProvider, null, pluginManagerService); } @AfterEach @@ -121,9 +121,9 @@ void testExecute_closeOldConnection() throws SQLException { @Test void testConnect() throws SQLException { - plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction); - verify(connectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any(), any()); - verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any()); + plugin.connect(mockConnectionInfo, mockHostSpec, true, mockConnectFunction); + verify(connectionProvider, atLeastOnce()).connect(any(), any()); + verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), any(), any()); } private static Stream multiStatementQueries() { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java index 0d41c5f72..902005830 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java @@ -48,6 +48,7 @@ import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -64,6 +65,7 @@ public class CustomEndpointPluginTest { private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); @Mock private FullServicesContainer mockServicesContainer; + @Mock private ConnectionInfo mockConnectionInfo; @Mock private PluginService mockPluginService; @Mock private MonitorService mockMonitorService; @Mock private BiFunction mockRdsClientFunc; @@ -106,7 +108,7 @@ private CustomEndpointPlugin getSpyPlugin() throws SQLException { public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { CustomEndpointPlugin spyPlugin = getSpyPlugin(); - spyPlugin.connect("", writerClusterHost, props, true, mockConnectFunc); + spyPlugin.connect(mockConnectionInfo, writerClusterHost, true, mockConnectFunc); verify(mockConnectFunc, times(1)).call(); verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); @@ -116,7 +118,7 @@ public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLExc public void testConnect_monitorCreated() throws SQLException { CustomEndpointPlugin spyPlugin = getSpyPlugin(); - spyPlugin.connect("", host, props, true, mockConnectFunc); + spyPlugin.connect(mockConnectionInfo, host, true, mockConnectFunc); verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); verify(mockConnectFunc, times(1)).call(); @@ -128,7 +130,7 @@ public void testConnect_timeoutWaitingForInfo() throws SQLException { CustomEndpointPlugin spyPlugin = getSpyPlugin(); when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); - assertThrows(SQLException.class, () -> spyPlugin.connect("", host, props, true, mockConnectFunc)); + assertThrows(SQLException.class, () -> spyPlugin.connect(mockConnectionInfo, host, true, mockConnectFunc)); verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); verify(mockConnectFunc, never()).call(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java index 5c0293bd3..cbfec7329 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java @@ -326,7 +326,7 @@ public void test_RaiseExceptionOnConnectWithCallback() { props.put(DialectManager.DIALECT.name, DialectCodes.PG); final SQLException exception = new SQLException("test"); - when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())) + when(mockConnectCallback.getExceptionToRaise(any(), any(), anyBoolean())) .thenReturn(exception) .thenReturn(null); ExceptionSimulatorManager.setCallback(mockConnectCallback); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorConnectionContextTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorConnectionContextTest.java index 7066bbe3c..290187c50 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorConnectionContextTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorConnectionContextTest.java @@ -35,7 +35,7 @@ import org.mockito.MockitoAnnotations; import software.amazon.jdbc.util.telemetry.TelemetryCounter; -class HostHostMonitorConnectionContextTest { +class HostMonitorConnectionContextTest { private static final long FAILURE_DETECTION_TIME_MILLIS = 10; private static final long FAILURE_DETECTION_INTERVAL_MILLIS = 100; private static final long FAILURE_DETECTION_COUNT = 3; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorServiceImplTest.java index e89279279..c9b1eca94 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorServiceImplTest.java @@ -49,7 +49,7 @@ import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; -class HostHostMonitorServiceImplTest { +class HostMonitorServiceImplTest { private static final Set NODE_KEYS = new HashSet<>(Collections.singletonList("any.node.domain")); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java index 63320e95d..b739e02a8 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java @@ -25,7 +25,6 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.atMostOnce; import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -50,8 +49,6 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.HostSpec; @@ -66,6 +63,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionInfo; class HostMonitoringConnectionPluginTest { @@ -79,9 +77,9 @@ class HostMonitoringConnectionPluginTest { @Mock PluginService pluginService; @Mock Dialect mockDialect; @Mock Connection connection; + @Mock ConnectionInfo mockConnectionInfo; @Mock Statement statement; @Mock ResultSet resultSet; - @Captor ArgumentCaptor stringArgumentCaptor; Properties properties = new Properties(); @Mock HostSpec hostSpec; @Mock HostSpec hostSpec2; @@ -96,23 +94,6 @@ class HostMonitoringConnectionPluginTest { private HostMonitoringConnectionPlugin plugin; private AutoCloseable closeable; - /** - * Generate different sets of method arguments where one argument is null to ensure {@link - * software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin#HostMonitoringConnectionPlugin(PluginService, - * Properties)} can handle null arguments correctly. - * - * @return different sets of arguments. - */ - private static Stream generateNullArguments() { - final PluginService pluginService = mock(PluginService.class); - final Properties properties = new Properties(); - - return Stream.of( - Arguments.of(null, null), - Arguments.of(pluginService, null), - Arguments.of(null, properties)); - } - @AfterEach void cleanUp() throws Exception { closeable.close(); @@ -227,7 +208,7 @@ void test_executeMonitoringEnabled() throws Exception { } /** - * Tests exception being thrown in the finally block when checking connection status in the execute method. + * Tests exception being thrown in the `finally` block when checking connection status in the execute method. */ @Test void test_executeCleanUp_whenCheckingConnectionStatus_throwsException() throws SQLException { @@ -248,7 +229,7 @@ void test_executeCleanUp_whenCheckingConnectionStatus_throwsException() throws S } /** - * Tests exception being thrown in the finally block + * Tests exception being thrown in the `finally` block * when an open connection object is detected for an unavailable node in the execute method. */ @Test @@ -282,7 +263,7 @@ void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException doThrow(new SQLException()).when(connection).createStatement(); // Ensure SQLException raised in `generateHostAliases` are ignored. - final Connection conn = plugin.connect("protocol", hostSpec, properties, true, () -> connection); + final Connection conn = plugin.connect(mockConnectionInfo, hostSpec, true, () -> connection); assertNotNull(conn); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostHostMonitorServiceTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostHostMonitorServiceTest.java index b4e76d3ac..afccbcc21 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostHostMonitorServiceTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostHostMonitorServiceTest.java @@ -60,10 +60,10 @@ import software.amazon.jdbc.util.telemetry.TelemetryFactory; /** - * Multithreaded tests for {@link MultiThreadedDefaultHostHostMonitorServiceTest}. Repeats each testcase + * Multithreaded tests for {@link MultiThreadedDefaultHostMonitorServiceTest}. Repeats each testcase * multiple times. Use a cyclic barrier to ensure threads start at the same time. */ -class MultiThreadedDefaultHostHostMonitorServiceTest { +class MultiThreadedDefaultHostMonitorServiceTest { @Mock HostMonitorInitializer monitorInitializer; @Mock ExecutorServiceInitializer executorServiceInitializer; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java index 7b72dae6e..a15d0e511 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -50,6 +50,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.plugin.iam.IamTokenUtility; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -57,7 +58,6 @@ class FederatedAuthPluginTest { private static final int DEFAULT_PORT = 1234; - private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; private static final HostSpec HOST_SPEC = @@ -77,6 +77,7 @@ class FederatedAuthPluginTest { @Mock private IamTokenUtility mockIamTokenUtils; @Mock private CompletableFuture completableFuture; @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; + @Mock private ConnectionInfo mockConnectionInfo; private Properties props; private AutoCloseable closeable; @@ -119,7 +120,7 @@ void testCachedToken() throws SQLException { String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -136,7 +137,7 @@ void testExpiredCachedToken() throws SQLException { someExpiredToken, Instant.now().minusMillis(300000)); FederatedAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, Region.US_EAST_2, HOST_SPEC.getHost(), @@ -151,7 +152,7 @@ void testNoCachedToken() throws SQLException { FederatedAuthPlugin spyPlugin = Mockito.spy( new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken( mockAwsCredentialsProvider, Region.US_EAST_2, @@ -178,7 +179,7 @@ void testSpecifiedIamHostPortRegion() throws SQLException { FederatedAuthPlugin plugin = new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -197,7 +198,7 @@ void testIdpCredentialsFallback() throws SQLException { String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -211,7 +212,7 @@ public void testUsingIamHost() throws SQLException { FederatedAuthPlugin spyPlugin = Mockito.spy( new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java index 910e06fe1..408f7075f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java @@ -47,6 +47,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.plugin.iam.IamTokenUtility; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -54,7 +55,6 @@ class OktaAuthPluginTest { private static final int DEFAULT_PORT = 1234; - private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com"; private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com"; @@ -73,6 +73,7 @@ class OktaAuthPluginTest { @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; @Mock private RdsUtils mockRdsUtils; @Mock private IamTokenUtility mockIamTokenUtils; + @Mock private ConnectionInfo mockConnectionInfo; private Properties props; private AutoCloseable closeable; @@ -114,7 +115,7 @@ void testCachedToken() throws SQLException { String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -131,7 +132,7 @@ void testExpiredCachedToken() throws SQLException { someExpiredToken, Instant.now().minusMillis(300000)); OktaAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, Region.US_EAST_2, HOST_SPEC.getHost(), @@ -146,7 +147,7 @@ void testNoCachedToken() throws SQLException { final OktaAuthPlugin spyPlugin = new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken( mockAwsCredentialsProvider, Region.US_EAST_2, @@ -173,7 +174,7 @@ void testSpecifiedIamHostPortRegion() throws SQLException { OktaAuthPlugin plugin = new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -192,7 +193,7 @@ void testIdpCredentialsFallback() throws SQLException { final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -206,7 +207,7 @@ public void testUsingIamHost() throws SQLException { OktaAuthPlugin spyPlugin = Mockito.spy( new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java index a872ac96c..eeab6fc0f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java @@ -53,7 +53,9 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -87,6 +89,7 @@ class IamAuthConnectionPluginTest { @Mock TelemetryContext mockTelemetryContext; @Mock JdbcCallable mockLambda; @Mock Dialect mockDialect; + @Mock TargetDriverDialect mockDriverDialect; @Mock private RdsUtils mockRdsUtils; @Mock private IamTokenUtility mockIamTokenUtils; private AutoCloseable closable; @@ -259,7 +262,8 @@ public void testTokenSetInProps(final String protocol, final HostSpec hostSpec) IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); doThrow(new SQLException()).when(mockLambda).call(); - assertThrows(SQLException.class, () -> targetPlugin.connect(protocol, hostSpec, props, true, mockLambda)); + ConnectionInfo connectionInfo = new ConnectionInfo(protocol + hostSpec.getHost(), mockDriverDialect, props); + assertThrows(SQLException.class, () -> targetPlugin.connect(connectionInfo, hostSpec, true, mockLambda)); verify(mockLambda, times(1)).call(); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -278,8 +282,9 @@ private void testGenerateToken( doThrow(new SQLException()).when(mockLambda).call(); + ConnectionInfo connectionInfo = new ConnectionInfo(protocol + hostSpec.getHost(), mockDriverDialect, props); assertThrows(SQLException.class, - () -> spyPlugin.connect(protocol, hostSpec, props, true, mockLambda)); + () -> spyPlugin.connect(connectionInfo, hostSpec, true, mockLambda)); verify(mockIamTokenUtils).generateAuthenticationToken( any(DefaultCredentialsProvider.class), diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java index 411233100..281171b76 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java @@ -16,7 +16,7 @@ package software.amazon.jdbc.plugin.limitless; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -33,10 +33,8 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import software.amazon.jdbc.HostListProvider; -import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.HostSpecBuilder; import software.amazon.jdbc.JdbcCallable; @@ -45,18 +43,17 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.dialect.PgDialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class LimitlessConnectionPluginTest { - private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; private static final HostSpec INPUT_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); private static final String CLUSTER_ID = "someClusterId"; - private static final HostSpec expectedSelectedHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) - .host("expected-selected-instance").role(HostRole.WRITER).weight(Long.MAX_VALUE).build(); private static final Dialect supportedDialect = new AuroraPgDialect(); @Mock JdbcCallable mockConnectFuncLambda; + @Mock ConnectionInfo mockConnectionInfo; @Mock private Connection mockConnection; @Mock private PluginService mockPluginService; @Mock private HostListProvider mockHostListProvider; @@ -86,16 +83,14 @@ void cleanUp() throws Exception { @Test void testConnect() throws SQLException { - doAnswer(new Answer() { - public Void answer(InvocationOnMock invocation) { - LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; - context.setConnection(mockConnection); - return null; - } + doAnswer((Answer) invocation -> { + LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; + context.setConnection(mockConnection); + return null; }).when(mockLimitlessRouterService).establishConnection(any()); final Connection expectedConnection = mockConnection; - final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, + final Connection actualConnection = plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, mockConnectFuncLambda); assertEquals(expectedConnection, actualConnection); @@ -108,17 +103,15 @@ public Void answer(InvocationOnMock invocation) { @Test void testConnectGivenNullConnection() throws SQLException { - doAnswer(new Answer() { - public Void answer(InvocationOnMock invocation) { - LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; - context.setConnection(null); - return null; - } + doAnswer((Answer) invocation -> { + LimitlessConnectionContext context = (LimitlessConnectionContext) invocation.getArguments()[0]; + context.setConnection(null); + return null; }).when(mockLimitlessRouterService).establishConnection(any()); assertThrows( SQLException.class, - () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); + () -> plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, mockConnectFuncLambda)); verify(mockPluginService, times(1)).getDialect(); verify(mockConnectFuncLambda, times(0)).call(); @@ -134,7 +127,7 @@ void testConnectGivenUnsupportedDialect() throws SQLException { assertThrows( UnsupportedOperationException.class, - () -> plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, mockConnectFuncLambda)); + () -> plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, mockConnectFuncLambda)); verify(mockPluginService, times(2)).getDialect(); verify(mockConnectFuncLambda, times(1)).call(); @@ -149,7 +142,7 @@ void testConnectGivenSupportedDialectAfterRefresh() throws SQLException { when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, supportedDialect); final Connection expectedConnection = mockConnection; - final Connection actualConnection = plugin.connect(DRIVER_PROTOCOL, INPUT_HOST_SPEC, props, true, + final Connection actualConnection = plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, mockConnectFuncLambda); assertEquals(expectedConnection, actualConnection); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index c7c7bdc1b..0d9c678d7 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -30,7 +30,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.zaxxer.hikari.HikariConfig; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -53,14 +52,12 @@ import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; import software.amazon.jdbc.util.SqlState; +import software.amazon.jdbc.util.connection.ConnectionInfo; public class ReadWriteSplittingPluginTest { - private static final String TEST_PROTOCOL = "jdbc:postgresql:"; private static final int TEST_PORT = 5432; private static final Properties defaultProps = new Properties(); @@ -95,7 +92,6 @@ public class ReadWriteSplittingPluginTest { @Mock private JdbcCallable mockConnectFunc; @Mock private JdbcCallable mockSqlFunction; @Mock private PluginService mockPluginService; - @Mock private Dialect mockDialect; @Mock private HostListProviderService mockHostListProviderService; @Mock private Connection mockWriterConn; @Mock private Connection mockNewWriterConn; @@ -105,6 +101,7 @@ public class ReadWriteSplittingPluginTest { @Mock private Connection mockReaderConn3; @Mock private Statement mockStatement; @Mock private ResultSet mockResultSet; + @Mock private ConnectionInfo mockConnectionInfo; @Mock private EnumSet mockChanges; @BeforeEach @@ -405,7 +402,7 @@ public void testConnectNonInitialConnection() throws SQLException { null); final Connection connection = - plugin.connect(TEST_PROTOCOL, writerHostSpec, defaultProps, false, this.mockConnectFunc); + plugin.connect(mockConnectionInfo, writerHostSpec, false, this.mockConnectFunc); assertEquals(mockWriterConn, connection); verify(mockConnectFunc).call(); @@ -424,9 +421,8 @@ public void testConnectRdsInstanceUrl() throws SQLException { null, null); final Connection connection = plugin.connect( - TEST_PROTOCOL, + mockConnectionInfo, instanceUrlHostSpec, - defaultProps, true, this.mockConnectFunc); @@ -447,7 +443,7 @@ public void testConnectReaderIpUrl() throws SQLException { null, null); final Connection connection = - plugin.connect(TEST_PROTOCOL, ipUrlHostSpec, defaultProps, true, this.mockConnectFunc); + plugin.connect(mockConnectionInfo, ipUrlHostSpec, true, this.mockConnectFunc); assertEquals(mockReaderConn1, connection); verify(mockConnectFunc).call(); @@ -463,7 +459,7 @@ public void testConnectClusterUrl() throws SQLException { null, null); final Connection connection = - plugin.connect(TEST_PROTOCOL, clusterUrlHostSpec, defaultProps, true, this.mockConnectFunc); + plugin.connect(mockConnectionInfo, clusterUrlHostSpec, true, this.mockConnectFunc); assertEquals(mockWriterConn, connection); verify(mockConnectFunc).call(); @@ -484,9 +480,8 @@ public void testConnect_errorUpdatingHostSpec() throws SQLException { assertThrows( SQLException.class, () -> plugin.connect( - TEST_PROTOCOL, + mockConnectionInfo, ipUrlHostSpec, - defaultProps, true, this.mockConnectFunc)); verify(mockHostListProviderService, times(0)).setInitialConnectionHostSpec(any(HostSpec.class)); @@ -538,7 +533,7 @@ public void testExecuteClearWarningsOnClosedConnectionsIsNotCalled() throws SQLE } @Test - public void testExecuteClearWarningsOnNullConnectionsIsNotCalled() throws SQLException { + public void testExecuteClearWarningsOnNullConnectionsIsNotCalled() { final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( mockPluginService, defaultProps, @@ -609,18 +604,4 @@ public void testClosePooledWriterConnectionAfterSetReadOnly() throws SQLExceptio verify(spyPlugin, times(1)).closeConnectionIfIdle(eq(mockWriterConn)); } - - private static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props) { - final HikariConfig config = new HikariConfig(); - config.setMaximumPoolSize(3); - config.setInitializationFailTimeout(75000); - config.setConnectionTimeout(10000); - return config; - } - - private static String getPoolKey(HostSpec hostSpec, Properties props) { - final String user = props.getProperty(PropertyDefinition.USER.name); - final String somePropertyValue = props.getProperty("somePropertyValue"); - return hostSpec.getUrl() + user + somePropertyValue; - } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 450b494a7..33ef71bcb 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -30,7 +30,6 @@ import java.sql.SQLException; import java.util.Collections; import java.util.HashSet; -import java.util.Properties; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -38,10 +37,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; -import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.connection.ConnectionInfo; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -49,14 +47,10 @@ class MonitorServiceImplTest { @Mock FullServicesContainer mockServicesContainer; @Mock StorageService mockStorageService; + @Mock ConnectionInfo mockConnectionInfo; @Mock ConnectionProvider mockConnectionProvider; @Mock TelemetryFactory mockTelemetryFactory; - @Mock TargetDriverDialect mockTargetDriverDialect; - @Mock Dialect mockDbDialect; @Mock EventPublisher mockPublisher; - String url = "jdbc:postgresql://somehost/somedb"; - String protocol = "someProtocol"; - Properties props = new Properties(); MonitorServiceImpl spyMonitorService; private AutoCloseable closeable; @@ -69,11 +63,7 @@ void setUp() throws SQLException { eq(mockStorageService), eq(mockConnectionProvider), eq(mockTelemetryFactory), - eq(url), - eq(protocol), - eq(mockTargetDriverDialect), - eq(mockDbDialect), - eq(props)); + eq(mockConnectionInfo)); } @AfterEach @@ -98,11 +88,7 @@ public void testMonitorError_monitorReCreated() throws SQLException, Interrupted mockStorageService, mockTelemetryFactory, mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, + mockConnectionInfo, (mockServicesContainer) -> new NoOpMonitor(30) ); @@ -142,11 +128,7 @@ public void testMonitorStuck_monitorReCreated() throws SQLException, Interrupted mockStorageService, mockTelemetryFactory, mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, + mockConnectionInfo, (mockServicesContainer) -> new NoOpMonitor(30) ); @@ -188,11 +170,7 @@ public void testMonitorExpired() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, + mockConnectionInfo, (mockServicesContainer) -> new NoOpMonitor(30) ); @@ -221,11 +199,7 @@ public void testMonitorMismatch() { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, + mockConnectionInfo, // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor // service should detect this and throw an exception. (mockServicesContainer) -> new NoOpMonitor(30) @@ -251,11 +225,7 @@ public void testRemove() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, + mockConnectionInfo, (mockServicesContainer) -> new NoOpMonitor(30) ); assertNotNull(monitor); @@ -286,11 +256,7 @@ public void testStopAndRemove() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - url, - protocol, - mockTargetDriverDialect, - mockDbDialect, - props, + mockConnectionInfo, (mockServicesContainer) -> new NoOpMonitor(30) ); assertNotNull(monitor); From 6b9e23ad9c5fcdbd605ee94ad74c70289cfceb60 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 24 Sep 2025 16:44:48 -0700 Subject: [PATCH 51/54] Checkstyle passing --- .../main/java/software/amazon/jdbc/PartialPluginService.java | 2 +- wrapper/src/main/java/software/amazon/jdbc/PluginService.java | 1 + ...onContextTest.java => HostMonitorConnectionContextTest.java} | 0 ...itorServiceImplTest.java => HostMonitorServiceImplTest.java} | 0 ...est.java => MultiThreadedDefaultHostMonitorServiceTest.java} | 0 5 files changed, 2 insertions(+), 1 deletion(-) rename wrapper/src/test/java/software/amazon/jdbc/plugin/efm/{HostHostMonitorConnectionContextTest.java => HostMonitorConnectionContextTest.java} (100%) rename wrapper/src/test/java/software/amazon/jdbc/plugin/efm/{HostHostMonitorServiceImplTest.java => HostMonitorServiceImplTest.java} (100%) rename wrapper/src/test/java/software/amazon/jdbc/plugin/efm/{MultiThreadedDefaultHostHostMonitorServiceTest.java => MultiThreadedDefaultHostMonitorServiceTest.java} (100%) diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 6658a2eb2..6ff9e9591 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -504,7 +504,7 @@ public Connection forceConnect( final Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { - return this.pluginManager.forceConnect(this.connectionInfo, hostSpec,true, pluginToSkip); + return this.pluginManager.forceConnect(this.connectionInfo, hostSpec, true, pluginToSkip); } private void updateHostAvailability(final List hosts) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index ffc7a62a5..13e60c794 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -84,6 +84,7 @@ EnumSet setCurrentConnection( /** * Get the {@link ConnectionInfo} for the current original connection. + * * @return the {@link ConnectionInfo} for the current original connection. */ ConnectionInfo getConnectionInfo(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorConnectionContextTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitorConnectionContextTest.java similarity index 100% rename from wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorConnectionContextTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitorConnectionContextTest.java diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitorServiceImplTest.java similarity index 100% rename from wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostHostMonitorServiceImplTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitorServiceImplTest.java diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostHostMonitorServiceTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostMonitorServiceTest.java similarity index 100% rename from wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostHostMonitorServiceTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultHostMonitorServiceTest.java From 0e8eac94d7117981a08d9526273b7241b2746999 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Wed, 24 Sep 2025 17:49:15 -0700 Subject: [PATCH 52/54] Fix unit tests --- .../amazon/jdbc/DialectDetectionTests.java | 2 +- .../HikariPooledConnectionProviderTest.java | 12 +++-- .../amazon/jdbc/PluginServiceImplTests.java | 50 ++++++++----------- .../amazon/jdbc/mock/TestPluginThree.java | 13 ++--- .../AuroraConnectionTrackerPluginTest.java | 8 +-- ...AwsSecretsManagerConnectionPluginTest.java | 4 +- .../plugin/DefaultConnectionPluginTest.java | 3 +- .../FederatedAuthPluginTest.java | 1 + .../federatedauth/OktaAuthPluginTest.java | 1 + .../iam/IamAuthConnectionPluginTest.java | 4 +- 10 files changed, 47 insertions(+), 51 deletions(-) diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java index f15bcf6ad..9ff354b2c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java @@ -95,7 +95,7 @@ void cleanUp() throws Exception { PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { return getPluginService( - new ConnectionInfo(host + protocol, protocol, mockDriverDialect, new Properties())); + new ConnectionInfo(protocol + host, protocol, mockDriverDialect, new Properties())); } PluginServiceImpl getPluginService(ConnectionInfo connectionInfo) throws SQLException { diff --git a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java index cc266443b..cb3e549bf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java @@ -45,6 +45,7 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.ConnectInfo; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; @@ -58,7 +59,8 @@ class HikariPooledConnectionProviderTest { @Mock HikariDataSource mockDataSource; @Mock HostSpec mockHostSpec; @Mock HikariConfig mockConfig; - @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock Dialect mockDbDialect; + @Mock TargetDriverDialect mockDriverDialect; @Mock HikariDataSource dsWithNoConnections; @Mock HikariDataSource dsWith1Connection; @Mock HikariDataSource dsWith2Connections; @@ -115,6 +117,10 @@ void init() throws SQLException { when(mxBeanWith1Connection.getActiveConnections()).thenReturn(1); when(dsWith2Connections.getHikariPoolMXBean()).thenReturn(mxBeanWith2Connections); when(mxBeanWith2Connections.getActiveConnections()).thenReturn(2); + when(mockConnectionInfo.getDriverDialect()).thenReturn(mockDriverDialect); + when(mockConnectionInfo.getDbDialect()).thenReturn(mockDbDialect); + when(mockConnectionInfo.getProps()).thenReturn(defaultProps); + when(mockConnectionInfo.getProtocol()).thenReturn(protocol); } @AfterEach @@ -136,7 +142,7 @@ void testConnectWithDefaultMapping() throws SQLException { doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any()); doReturn(new ConnectInfo("url", new Properties())) - .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); + .when(mockDriverDialect).prepareConnectInfo(anyString(), any(), any()); try (Connection conn = provider.connect(mockConnectionInfo, mockHostSpec)) { assertEquals(mockConnection, conn); @@ -223,7 +229,7 @@ public void testConfigurePool() throws SQLException { final String expectedJdbcUrl = protocol + readerHost1Connection.getUrl() + db + "?database=" + db; doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) - .when(mockTargetDriverDialect).prepareConnectInfo(anyString(), any(), any()); + .when(mockDriverDialect).prepareConnectInfo(anyString(), any(), any()); provider.configurePool(mockConfig, mockConnectionInfo, readerHost1Connection, defaultProps); verify(mockConfig).setJdbcUrl(expectedJdbcUrl); diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index 0e99b9eb4..8ac043902 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -42,6 +42,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Properties; import java.util.Set; import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; @@ -69,6 +70,7 @@ public class PluginServiceImplTests { private StorageService storageService; private AutoCloseable closeable; + private Properties props = new Properties(); @Mock FullServicesContainer servicesContainer; @Mock EventPublisher mockEventPublisher; @@ -92,6 +94,9 @@ void setUp() throws SQLException { when(statement.executeQuery(any())).thenReturn(resultSet); when(servicesContainer.getConnectionPluginManager()).thenReturn(pluginManager); when(servicesContainer.getStorageService()).thenReturn(storageService); + when(mockConnectionInfo.getProps()).thenReturn(props); + when(mockConnectionInfo.getInitialConnectionString()).thenReturn("url"); + when(mockConnectionInfo.getProtocol()).thenReturn("jdbc:postgresql://"); storageService = new TestStorageServiceImpl(mockEventPublisher); PluginServiceImpl.hostAvailabilityExpiringCache.clear(); } @@ -309,8 +314,7 @@ public void testSetNodeListAdded() throws SQLException { when(hostListProvider.refresh()).thenReturn(Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build())); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = new ArrayList<>(); target.hostListProvider = hostListProvider; @@ -334,8 +338,7 @@ public void testSetNodeListDeleted() throws SQLException { when(hostListProvider.refresh()).thenReturn(Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build())); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Arrays.asList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); @@ -362,8 +365,7 @@ public void testSetNodeListChanged() throws SQLException { Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA") .port(HostSpec.NO_PORT).role(HostRole.READER).build())); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); target.hostListProvider = hostListProvider; @@ -390,8 +392,7 @@ public void testSetNodeListNoChanges() throws SQLException { Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build())); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); target.hostListProvider = hostListProvider; @@ -407,8 +408,7 @@ public void testSetNodeListNoChanges() throws SQLException { public void testNodeAvailabilityNotChanged() throws SQLException { doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) @@ -427,8 +427,7 @@ public void testNodeAvailabilityNotChanged() throws SQLException { public void testNodeAvailabilityChanged_WentDown() throws SQLException { doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) @@ -454,8 +453,7 @@ public void testNodeAvailabilityChanged_WentDown() throws SQLException { public void testNodeAvailabilityChanged_WentUp() throws SQLException { doNothing().when(pluginManager).notifyNodeListChanged(argumentChangesMap.capture()); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) @@ -492,8 +490,7 @@ public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { hostB.addAlias("ip-10-10-10-10"); hostB.addAlias("hostB.custom.domain.com"); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Arrays.asList(hostA, hostB); @@ -528,8 +525,7 @@ public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQL hostB.addAlias("ip-10-10-10-10"); hostB.addAlias("hostB.custom.domain.com"); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.allHosts = Arrays.asList(hostA, hostB); @@ -597,8 +593,7 @@ void testRefreshHostList_withCachedHostAvailability() throws SQLException { when(hostListProvider.refresh()).thenReturn(newHostSpecs); when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); when(target.getHostListProvider()).thenReturn(hostListProvider); assertNotEquals(expectedHostSpecs, newHostSpecs); @@ -645,8 +640,7 @@ void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { when(hostListProvider.forceRefresh()).thenReturn(newHostSpecs); when(hostListProvider.forceRefresh(newConnection)).thenReturn(newHostSpecs); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); when(target.getHostListProvider()).thenReturn(hostListProvider); assertNotEquals(expectedHostSpecs, newHostSpecs); @@ -661,8 +655,7 @@ void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { @Test void testIdentifyConnectionWithNoAliases() throws SQLException { - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); when(target.getHostListProvider()).thenReturn(hostListProvider); when(target.getDialect()).thenReturn(new MysqlDialect()); @@ -673,8 +666,7 @@ void testIdentifyConnectionWithNoAliases() throws SQLException { void testIdentifyConnectionWithAliases() throws SQLException { final HostSpec expected = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("test") .build(); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.hostListProvider = hostListProvider; when(target.getHostListProvider()).thenReturn(hostListProvider); when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); @@ -692,8 +684,7 @@ void testFillAliasesNonEmptyAliases() throws SQLException { .build(); oneAlias.addAlias(oneAlias.asAlias()); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); assertEquals(1, oneAlias.getAliases().size()); target.fillAliases(newConnection, oneAlias); @@ -705,8 +696,7 @@ void testFillAliasesNonEmptyAliases() throws SQLException { @MethodSource("fillAliasesDialects") void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { final HostSpec empty = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("foo").build(); - PluginServiceImpl target = spy( - getPluginService()); + PluginServiceImpl target = spy(getPluginService()); target.hostListProvider = hostListProvider; when(target.getDialect()).thenReturn(dialect); when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java index c26cc1dab..7b1f85689 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java @@ -21,7 +21,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; -import java.util.Properties; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.JdbcMethod; @@ -64,14 +63,12 @@ public Connection connect( return result; } + @Override public Connection forceConnect( - String driverProtocol, - HostSpec hostSpec, - Properties props, - boolean isInitialConnection, - JdbcCallable forceConnectFunc) - throws SQLException { - + final ConnectionInfo connectionInfo, + final HostSpec hostSpec, + final boolean isInitialConnection, + final JdbcCallable forceConnectFunc) throws SQLException { this.calls.add(this.getClass().getSimpleName() + ":before forceConnect"); if (this.connection != null) { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index 8af8f65d0..226aacffa 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -234,10 +234,10 @@ static Stream testInvalidateConnectionsOnCloseOrAbortArgs() { private static Stream trackNewConnectionsParameters() { return Stream.of( - Arguments.of("postgresql", true), - Arguments.of("postgresql", false), - Arguments.of("otherProtocol", true), - Arguments.of("otherProtocol", false) + Arguments.of("jdbc:postgresql://", true), + Arguments.of("jdbc:postgresql://", false), + Arguments.of("jdbc:otherProtocol://", true), + Arguments.of("jdbc:otherProtocol://", false) ); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java index aa88ed32b..0d6f8853f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java @@ -79,8 +79,8 @@ @SuppressWarnings("resource") public class AwsSecretsManagerConnectionPluginTest { - private static final String TEST_PG_PROTOCOL = "jdbc:aws-wrapper:postgresql:"; - private static final String TEST_MYSQL_PROTOCOL = "jdbc:aws-wrapper:mysql:"; + private static final String TEST_PG_PROTOCOL = "jdbc:aws-wrapper:postgresql://"; + private static final String TEST_MYSQL_PROTOCOL = "jdbc:aws-wrapper:mysql://"; private static final String TEST_REGION = "us-east-2"; private static final String TEST_SECRET_ID = "secretId"; private static final String TEST_USERNAME = "testUser"; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java index 070d7028f..e36e9d4a4 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java @@ -90,7 +90,8 @@ void setUp() { when(mockConnectionProviderManager.getConnectionProvider(any(), any())) .thenReturn(connectionProvider); - plugin = new DefaultConnectionPlugin(pluginService, connectionProvider, null, pluginManagerService); + plugin = new DefaultConnectionPlugin( + pluginService, connectionProvider, pluginManagerService, mockConnectionProviderManager); } @AfterEach diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java index a15d0e511..f17fc3e2a 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -104,6 +104,7 @@ public void init() throws ExecutionException, InterruptedException, SQLException when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) .thenReturn(mockAwsCredentialsProvider); when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); + when(mockConnectionInfo.getProps()).thenReturn(props); when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java index 408f7075f..149d2e5d1 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java @@ -96,6 +96,7 @@ void setUp() throws SQLException { when(mockPluginService.getDialect()).thenReturn(mockDialect); when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockConnectionInfo.getProps()).thenReturn(props); when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java index eeab6fc0f..fe48edd8c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java @@ -71,8 +71,8 @@ class IamAuthConnectionPluginTest { + DEFAULT_PG_PORT + ":postgresqlUser"; private static final String MYSQL_CACHE_KEY = "us-east-2:mysql.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_MYSQL_PORT + ":mysqlUser"; - private static final String PG_DRIVER_PROTOCOL = "jdbc:postgresql:"; - private static final String MYSQL_DRIVER_PROTOCOL = "jdbc:mysql:"; + private static final String PG_DRIVER_PROTOCOL = "jdbc:postgresql://"; + private static final String MYSQL_DRIVER_PROTOCOL = "jdbc:mysql://"; private static final HostSpec PG_HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); private static final HostSpec PG_HOST_SPEC_WITH_PORT = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) From fcad2dfc8f44d452a0c53cdb70e4d9f0b5582cee Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 26 Sep 2025 17:05:19 -0700 Subject: [PATCH 53/54] ConnectionInfo -> ConnectConfig --- .../ConnectionPluginManagerBenchmarks.java | 16 ++--- .../jdbc/benchmarks/PluginBenchmarks.java | 4 +- .../testplugin/BenchmarkPlugin.java | 10 +-- docs/development-guide/LoadablePlugins.md | 4 +- .../jdbc/C3P0PooledConnectionProvider.java | 20 +++--- .../amazon/jdbc/ConnectionPlugin.java | 14 ++-- .../amazon/jdbc/ConnectionPluginManager.java | 20 +++--- .../amazon/jdbc/ConnectionProvider.java | 10 +-- .../jdbc/ConnectionProviderManager.java | 16 ++--- .../jdbc/DataSourceConnectionProvider.java | 26 +++---- .../amazon/jdbc/DriverConnectionProvider.java | 18 ++--- .../jdbc/HikariPooledConnectionProvider.java | 28 ++++---- .../amazon/jdbc/PartialPluginService.java | 48 ++++++------- .../software/amazon/jdbc/PluginService.java | 8 +-- .../amazon/jdbc/PluginServiceImpl.java | 68 +++++++++---------- .../jdbc/dialect/AuroraMysqlDialect.java | 6 +- .../amazon/jdbc/dialect/AuroraPgDialect.java | 6 +- .../amazon/jdbc/dialect/DialectManager.java | 20 +++--- .../amazon/jdbc/dialect/DialectProvider.java | 4 +- .../dialect/HostListProviderSupplier.java | 4 +- .../amazon/jdbc/dialect/MariaDbDialect.java | 4 +- .../amazon/jdbc/dialect/MysqlDialect.java | 4 +- .../amazon/jdbc/dialect/PgDialect.java | 4 +- .../RdsMultiAzDbClusterMysqlDialect.java | 6 +- .../dialect/RdsMultiAzDbClusterPgDialect.java | 6 +- .../amazon/jdbc/dialect/UnknownDialect.java | 4 +- .../AuroraHostListProvider.java | 6 +- .../ConnectionStringHostListProvider.java | 18 ++--- .../hostlistprovider/RdsHostListProvider.java | 14 ++-- .../RdsMultiAzDbClusterListProvider.java | 6 +- .../ClusterTopologyMonitorImpl.java | 2 +- .../MonitoringRdsHostListProvider.java | 14 ++-- .../MonitoringRdsMultiAzHostListProvider.java | 10 +-- .../jdbc/plugin/AbstractConnectionPlugin.java | 8 +-- .../plugin/AuroraConnectionTrackerPlugin.java | 4 +- ...AuroraInitialConnectionStrategyPlugin.java | 8 +-- .../AwsSecretsManagerConnectionPlugin.java | 10 +-- .../plugin/ConnectTimeConnectionPlugin.java | 6 +- .../jdbc/plugin/DefaultConnectionPlugin.java | 20 +++--- .../bluegreen/BlueGreenConnectionPlugin.java | 4 +- .../bluegreen/BlueGreenStatusMonitor.java | 6 +- .../customendpoint/CustomEndpointPlugin.java | 6 +- .../plugin/dev/DeveloperConnectionPlugin.java | 18 ++--- .../ExceptionSimulatorConnectCallback.java | 4 +- .../efm/HostMonitoringConnectionPlugin.java | 4 +- .../plugin/efm2/HostMonitorServiceImpl.java | 2 +- .../efm2/HostMonitoringConnectionPlugin.java | 4 +- .../ClusterAwareReaderFailoverHandler.java | 2 +- .../ClusterAwareWriterFailoverHandler.java | 2 +- .../failover/FailoverConnectionPlugin.java | 8 +-- .../failover2/FailoverConnectionPlugin.java | 12 ++-- .../federatedauth/FederatedAuthPlugin.java | 10 +-- .../plugin/federatedauth/OktaAuthPlugin.java | 10 +-- .../plugin/iam/IamAuthConnectionPlugin.java | 14 ++-- .../limitless/LimitlessConnectionPlugin.java | 6 +- .../limitless/LimitlessRouterServiceImpl.java | 2 +- .../ReadWriteSplittingPlugin.java | 6 +- .../plugin/staledns/AuroraStaleDnsHelper.java | 6 +- .../plugin/staledns/AuroraStaleDnsPlugin.java | 8 +-- .../FastestResponseStrategyPlugin.java | 4 +- .../HostResponseTimeServiceImpl.java | 2 +- .../amazon/jdbc/util/ServiceUtility.java | 8 +-- ...ConnectionInfo.java => ConnectConfig.java} | 6 +- .../connection/ConnectionServiceImpl.java | 12 ++-- .../jdbc/util/monitoring/MonitorService.java | 6 +- .../util/monitoring/MonitorServiceImpl.java | 10 +-- .../jdbc/wrapper/ConnectionWrapper.java | 14 ++-- .../aurora/TestAuroraHostListProvider.java | 6 +- .../aurora/TestPluginServiceImpl.java | 6 +- .../jdbc/ConnectionPluginManagerTests.java | 22 +++--- .../amazon/jdbc/DialectDetectionTests.java | 14 ++-- .../HikariPooledConnectionProviderTest.java | 26 +++---- .../amazon/jdbc/PluginServiceImplTests.java | 12 ++-- .../RdsHostListProviderTest.java | 6 +- .../RdsMultiAzDbClusterListProviderTest.java | 6 +- .../amazon/jdbc/mock/TestPluginOne.java | 8 +-- .../amazon/jdbc/mock/TestPluginThree.java | 6 +- .../jdbc/mock/TestPluginThrowException.java | 4 +- .../AuroraConnectionTrackerPluginTest.java | 8 +-- ...AwsSecretsManagerConnectionPluginTest.java | 34 +++++----- .../plugin/DefaultConnectionPluginTest.java | 6 +- .../CustomEndpointPluginTest.java | 10 +-- .../HostMonitoringConnectionPluginTest.java | 6 +- .../FederatedAuthPluginTest.java | 18 ++--- .../federatedauth/OktaAuthPluginTest.java | 18 ++--- .../iam/IamAuthConnectionPluginTest.java | 10 +-- .../LimitlessConnectionPluginTest.java | 12 ++-- .../ReadWriteSplittingPluginTest.java | 14 ++-- .../monitoring/MonitorServiceImplTest.java | 18 ++--- 89 files changed, 485 insertions(+), 485 deletions(-) rename wrapper/src/main/java/software/amazon/jdbc/util/connection/{ConnectionInfo.java => ConnectConfig.java} (92%) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java index 260aa8e1d..f4f33e368 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java @@ -68,7 +68,7 @@ import software.amazon.jdbc.profile.ConfigurationProfileBuilder; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; @@ -89,14 +89,14 @@ public class ConnectionPluginManagerBenchmarks { private static final String FIELD_SERVER_ID = "SERVER_ID"; private static final String FIELD_SESSION_ID = "SESSION_ID"; private static final String url = "protocol//url"; - private ConnectionInfo pluginsContext; - private ConnectionInfo noPluginsContext; + private ConnectConfig pluginsContext; + private ConnectConfig noPluginsContext; private ConnectionPluginManager pluginManager; private ConnectionPluginManager pluginManagerWithNoPlugins; @Mock ConnectionProvider mockConnectionProvider; @Mock ConnectionWrapper mockConnectionWrapper; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock FullServicesContainer mockServicesContainer; @Mock PluginService mockPluginService; @Mock PluginManagerService mockPluginManagerService; @@ -153,12 +153,12 @@ public void setUpIteration() throws Exception { Properties noPluginsProps = new Properties(); noPluginsProps.setProperty(PropertyDefinition.PLUGINS.name, ""); - this.noPluginsContext = new ConnectionInfo(url, mockDriverDialect, noPluginsProps); + this.noPluginsContext = new ConnectConfig(url, mockDriverDialect, noPluginsProps); Properties pluginsProps = new Properties(); pluginsProps.setProperty(PropertyDefinition.PROFILE_NAME.name, "benchmark"); pluginsProps.setProperty(PropertyDefinition.ENABLE_TELEMETRY.name, "false"); - this.pluginsContext = new ConnectionInfo(url, mockDriverDialect, pluginsProps); + this.pluginsContext = new ConnectConfig(url, mockDriverDialect, pluginsProps); TelemetryFactory telemetryFactory = new DefaultTelemetryFactory(pluginsProps); @@ -197,7 +197,7 @@ public ConnectionPluginManager initConnectionPluginManagerWithPlugins() throws S @Benchmark public Connection connectWithPlugins() throws SQLException { return pluginManager.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), true, null); @@ -206,7 +206,7 @@ public Connection connectWithPlugins() throws SQLException { @Benchmark public Connection connectWithNoPlugins() throws SQLException { return pluginManagerWithNoPlugins.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), true, null); diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java index 6579f336e..5554ae0fd 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java @@ -63,7 +63,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.GaugeCallable; @@ -134,7 +134,7 @@ public void setUpIteration() throws Exception { when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); // noinspection unchecked when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge); - when(mockConnectionProvider.connect(any(ConnectionInfo.class), any(HostSpec.class))).thenReturn(mockConnection); + when(mockConnectionProvider.connect(any(ConnectConfig.class), any(HostSpec.class))).thenReturn(mockConnection); when(mockConnection.createStatement()).thenReturn(mockStatement); when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); when(mockResultSet.next()).thenReturn(true, true, false); diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java index 83227da90..40f76e316 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/testplugin/BenchmarkPlugin.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class BenchmarkPlugin implements ConnectionPlugin, CanReleaseResources { final List resources = new ArrayList<>(); @@ -59,7 +59,7 @@ public T execute(Class resultClass, Class excepti @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -70,7 +70,7 @@ public Connection connect( @Override public Connection forceConnect( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { LOGGER.finer(() -> String.format("forceConnect=''%s''", hostSpec.getHost())); @@ -97,10 +97,10 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin @Override public void initHostProvider( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostListProviderService hostListProviderService, JdbcCallable initHostProviderFunc) { - LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectionInfo.getInitialConnectionString())); + LOGGER.finer(() -> String.format("initHostProvider=''%s''", connectConfig.getInitialConnectionString())); resources.add("initHostProvider"); } diff --git a/docs/development-guide/LoadablePlugins.md b/docs/development-guide/LoadablePlugins.md index 42c4b97e3..bd44f1342 100644 --- a/docs/development-guide/LoadablePlugins.md +++ b/docs/development-guide/LoadablePlugins.md @@ -118,7 +118,7 @@ public class BadPlugin extends AbstractConnectionPlugin { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -172,7 +172,7 @@ public class GoodExample extends AbstractConnectionPlugin { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java index 9f5f4efbc..9c055f407 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java @@ -34,7 +34,7 @@ import software.amazon.jdbc.targetdriverdialect.ConnectInfo; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class C3P0PooledConnectionProvider implements PooledConnectionProvider, CanReleaseResources { @@ -55,7 +55,7 @@ public class C3P0PooledConnectionProvider implements PooledConnectionProvider, C protected static final long poolExpirationCheckNanos = TimeUnit.MINUTES.toNanos(30); @Override - public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) { return true; } @@ -79,14 +79,14 @@ public HostSpec getHostSpecByStrategy(@NonNull List hosts, @NonNull Ho @Override public Connection connect( - @NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) throws SQLException { - Dialect dialect = connectionInfo.getDbDialect(); - Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); - dialect.prepareConnectProperties(propsCopy, connectionInfo.getProtocol(), hostSpec); + @NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) throws SQLException { + Dialect dialect = connectConfig.getDbDialect(); + Properties propsCopy = PropertyUtils.copyProperties(connectConfig.getProps()); + dialect.prepareConnectProperties(propsCopy, connectConfig.getProtocol(), hostSpec); final ComboPooledDataSource ds = databasePools.computeIfAbsent( hostSpec.getUrl(), - (key) -> createDataSource(connectionInfo, hostSpec, propsCopy), + (key) -> createDataSource(connectConfig, hostSpec, propsCopy), poolExpirationCheckNanos ); @@ -96,14 +96,14 @@ public Connection connect( } protected ComboPooledDataSource createDataSource( - @NonNull ConnectionInfo connectionInfo, + @NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec, @NonNull Properties props) { ConnectInfo connectInfo; try { - connectInfo = connectionInfo.getDriverDialect() - .prepareConnectInfo(connectionInfo.getProtocol(), hostSpec, props); + connectInfo = connectConfig.getDriverDialect() + .prepareConnectInfo(connectConfig.getProtocol(), hostSpec, props); } catch (SQLException ex) { throw new RuntimeException(ex); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java index f0b0c8aa8..3a815960a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPlugin.java @@ -22,7 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * Interface for connection plugins. This class implements ways to execute a JDBC method and to clean up resources used @@ -45,7 +45,7 @@ T execute( * Establishes a connection to the given host using the given driver protocol and properties. If a * non-default {@link ConnectionProvider} has been set with * {@link Driver#setCustomConnectionProvider(ConnectionProvider)} and - * {@link ConnectionProvider#acceptsUrl(ConnectionInfo, HostSpec)} returns true for the given + * {@link ConnectionProvider#acceptsUrl(ConnectConfig, HostSpec)} returns true for the given * protocol, host, and properties, the connection will be created by the non-default * ConnectionProvider. Otherwise, the connection will be created by the default * ConnectionProvider. The default ConnectionProvider will be {@link DriverConnectionProvider} for @@ -53,7 +53,7 @@ T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionInfo the connection info for the original connection + * @param connectConfig the connection info for the original connection * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -65,7 +65,7 @@ T execute( * host */ Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) @@ -80,7 +80,7 @@ Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -92,7 +92,7 @@ Connection connect( * host */ Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) @@ -132,7 +132,7 @@ HostSpec getHostSpecByStrategy(final List hosts, final HostRole role, throws SQLException, UnsupportedOperationException; void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException; diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 8c1b64da7..61801a6f2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -52,7 +52,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; @@ -371,7 +371,7 @@ public T execute( * Establishes a connection to the given host using the given driver protocol and properties. If a * non-default {@link ConnectionProvider} has been set with * {@link Driver#setCustomConnectionProvider(ConnectionProvider)} and - * {@link ConnectionProvider#acceptsUrl(ConnectionInfo, HostSpec)} returns true for the given + * {@link ConnectionProvider#acceptsUrl(ConnectConfig, HostSpec)} returns true for the given * protocol, host, and properties, the connection will be created by the non-default * ConnectionProvider. Otherwise, the connection will be created by the default * ConnectionProvider. The default ConnectionProvider will be {@link DriverConnectionProvider} for @@ -379,7 +379,7 @@ public T execute( * {@link DataSourceConnectionProvider} for connections requested via an * {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionInfo the connection info for the original connection + * @param connectConfig the connection info for the original connection * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -390,7 +390,7 @@ public T execute( * host */ public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final @Nullable ConnectionPlugin pluginToSkip) @@ -403,7 +403,7 @@ public Connection connect( return executeWithSubscribedPlugins( JdbcMethod.CONNECT, (plugin, func) -> - plugin.connect(connectionInfo, hostSpec, isInitialConnection, func), + plugin.connect(connectConfig, hostSpec, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); }, @@ -428,7 +428,7 @@ public Connection connect( * requested via the {@link java.sql.DriverManager} and {@link DataSourceConnectionProvider} for * connections requested via an {@link software.amazon.jdbc.ds.AwsWrapperDataSource}. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec the host details for the desired connection * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has @@ -439,7 +439,7 @@ public Connection connect( * host */ public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final @Nullable ConnectionPlugin pluginToSkip) @@ -449,7 +449,7 @@ public Connection forceConnect( return executeWithSubscribedPlugins( JdbcMethod.FORCECONNECT, (plugin, func) -> - plugin.forceConnect(connectionInfo, hostSpec, isInitialConnection, func), + plugin.forceConnect(connectConfig, hostSpec, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); }, @@ -550,7 +550,7 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin } public void initHostProvider( - final ConnectionInfo connectionInfo, final HostListProviderService hostListProviderService) + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService) throws SQLException { TelemetryContext context = this.telemetryFactory.openTelemetryContext( "initHostProvider", TelemetryTraceLevel.NESTED); @@ -560,7 +560,7 @@ public void initHostProvider( JdbcMethod.INITHOSTPROVIDER, (PluginPipeline) (plugin, func) -> { - plugin.initHostProvider(connectionInfo, hostListProviderService, func); + plugin.initHostProvider(connectConfig, hostListProviderService, func); return null; }, () -> { diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java index 46a1ee00b..b2980e09d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProvider.java @@ -22,7 +22,7 @@ import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * Implement this interface in order to handle the physical connection creation process. @@ -33,12 +33,12 @@ public interface ConnectionProvider { * properties. Some ConnectionProvider implementations may not be able to handle certain URL * types or properties. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec the HostSpec containing the host-port information for the host to connect to * @return true if this ConnectionProvider can provide connections for the given URL, otherwise * return false */ - boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec); + boolean acceptsUrl(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec); /** * Indicates whether the selection strategy is supported by the connection provider. @@ -70,12 +70,12 @@ HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec the HostSpec containing the host-port information for the host to connect to * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ - Connection connect(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) throws SQLException; + Connection connect(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) throws SQLException; String getTargetName(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java index f99f6051b..6ea63110a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionProviderManager.java @@ -23,7 +23,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class ConnectionProviderManager { @@ -66,19 +66,19 @@ public static void setConnectionProvider(ConnectionProvider connProvider) { * non-default ConnectionProvider will be returned. Otherwise, the default ConnectionProvider will * be returned. See {@link ConnectionProvider#acceptsUrl} for more info. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param host the host info for the connection that will be established * @return the {@link ConnectionProvider} to use to establish a connection using the given driver * protocol, host details, and properties */ - public ConnectionProvider getConnectionProvider(ConnectionInfo connectionInfo, HostSpec host) { + public ConnectionProvider getConnectionProvider(ConnectConfig connectConfig, HostSpec host) { final ConnectionProvider customConnectionProvider = Driver.getCustomConnectionProvider(); - if (customConnectionProvider != null && customConnectionProvider.acceptsUrl(connectionInfo, host)) { + if (customConnectionProvider != null && customConnectionProvider.acceptsUrl(connectConfig, host)) { return customConnectionProvider; } - if (this.effectiveConnProvider != null && this.effectiveConnProvider.acceptsUrl(connectionInfo, host)) { + if (this.effectiveConnProvider != null && this.effectiveConnProvider.acceptsUrl(connectConfig, host)) { return this.effectiveConnProvider; } @@ -207,7 +207,7 @@ public static void resetConnectionInitFunc() { public void initConnection( final @Nullable Connection connection, - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec) throws SQLException { final ConnectionInitFunc connectionInitFunc = Driver.getConnectionInitFunc(); @@ -215,13 +215,13 @@ public void initConnection( return; } - connectionInitFunc.initConnection(connection, connectionInfo, hostSpec); + connectionInitFunc.initConnection(connection, connectConfig, hostSpec); } public interface ConnectionInitFunc { void initConnection( final @Nullable Connection connection, - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec) throws SQLException; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java index b1342ab3c..b4defa94f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DataSourceConnectionProvider.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * This class is a basic implementation of {@link ConnectionProvider} interface. It creates and @@ -67,7 +67,7 @@ public DataSourceConnectionProvider(final @NonNull DataSource dataSource) { } @Override - public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) { return true; } @@ -93,16 +93,16 @@ public HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec The HostSpec containing the host-port information for the host to connect to * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ @Override public Connection connect( - final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); - connectionInfo.getDbDialect().prepareConnectProperties(propsCopy, connectionInfo.getProtocol(), hostSpec); + final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec) throws SQLException { + final Properties propsCopy = PropertyUtils.copyProperties(connectConfig.getProps()); + connectConfig.getDbDialect().prepareConnectProperties(propsCopy, connectConfig.getProtocol(), hostSpec); Connection conn; @@ -111,7 +111,7 @@ public Connection connect( LOGGER.finest(() -> "Use a separate DataSource object to create a connection."); // use a new data source instance to instantiate a connection final DataSource ds = createDataSource(); - conn = this.openConnection(ds, connectionInfo, hostSpec, propsCopy); + conn = this.openConnection(ds, connectConfig, hostSpec, propsCopy); } else { @@ -120,7 +120,7 @@ public Connection connect( this.lock.lock(); LOGGER.finest(() -> "Use main DataSource object to create a connection."); try { - conn = this.openConnection(this.dataSource, connectionInfo, hostSpec, propsCopy); + conn = this.openConnection(this.dataSource, connectConfig, hostSpec, propsCopy); } finally { this.lock.unlock(); } @@ -135,15 +135,15 @@ public Connection connect( protected Connection openConnection( final @NonNull DataSource ds, - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec, final @NonNull Properties props) throws SQLException { final boolean enableGreenNodeReplacement = PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(props); try { - connectionInfo.getDriverDialect().prepareDataSource( + connectConfig.getDriverDialect().prepareDataSource( ds, - connectionInfo.getProtocol(), + connectConfig.getProtocol(), hostSpec, props); return ds.getConnection(); @@ -192,9 +192,9 @@ protected Connection openConnection( .host(fixedHost) .build(); - connectionInfo.getDriverDialect().prepareDataSource( + connectConfig.getDriverDialect().prepareDataSource( this.dataSource, - connectionInfo.getProtocol(), + connectConfig.getProtocol(), connectionHostSpec, props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java index 729bad685..306b8b5f2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java @@ -33,7 +33,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * This class is a basic implementation of {@link ConnectionProvider} interface. It creates and @@ -63,7 +63,7 @@ public DriverConnectionProvider(final java.sql.Driver driver) { } @Override - public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) { return true; } @@ -89,19 +89,19 @@ public HostSpec getHostSpecByStrategy( /** * Called once per connection that needs to be created. * - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec The HostSpec containing the host-port information for the host to connect to * @return {@link Connection} resulting from the given connection information * @throws SQLException if an error occurs */ @Override - public Connection connect(final @NonNull ConnectionInfo connectionInfo, final @NonNull HostSpec hostSpec) + public Connection connect(final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); + final Properties propsCopy = PropertyUtils.copyProperties(connectConfig.getProps()); final ConnectInfo connectInfo = - connectionInfo.getDriverDialect().prepareConnectInfo(connectionInfo.getProtocol(), hostSpec, propsCopy); + connectConfig.getDriverDialect().prepareConnectInfo(connectConfig.getProtocol(), hostSpec, propsCopy); - connectionInfo.getDbDialect().prepareConnectProperties(propsCopy, connectionInfo.getProtocol(), hostSpec); + connectConfig.getDbDialect().prepareConnectProperties(propsCopy, connectConfig.getProtocol(), hostSpec); LOGGER.finest(() -> "Connecting to " + connectInfo.url + PropertyUtils.logProperties( PropertyUtils.maskProperties(connectInfo.props), @@ -158,8 +158,8 @@ public Connection connect(final @NonNull ConnectionInfo connectionInfo, final @N .host(fixedHost) .build(); - final ConnectInfo fixedConnectInfo = connectionInfo.getDriverDialect().prepareConnectInfo( - connectionInfo.getProtocol(), connectionHostSpec, propsCopy); + final ConnectInfo fixedConnectInfo = connectConfig.getDriverDialect().prepareConnectInfo( + connectConfig.getProtocol(), connectionHostSpec, propsCopy); LOGGER.finest(() -> "Connecting to " + fixedConnectInfo.url + " after correcting the hostname from " + originalHost diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java index af0100f54..032924c65 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.SlidingExpirationCache; public class HikariPooledConnectionProvider implements PooledConnectionProvider, @@ -203,9 +203,9 @@ public HikariPooledConnectionProvider( @Override - public boolean acceptsUrl(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) { + public boolean acceptsUrl(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) { if (this.acceptsUrlFunc != null) { - return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectionInfo.getProps()); + return this.acceptsUrlFunc.acceptsUrl(hostSpec, connectConfig.getProps()); } final RdsUrlType urlType = rdsUtils.identifyRdsType(hostSpec.getHost()); @@ -238,9 +238,9 @@ public HostSpec getHostSpecByStrategy( } @Override - public Connection connect(@NonNull ConnectionInfo connectionInfo, @NonNull HostSpec hostSpec) + public Connection connect(@NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec) throws SQLException { - final Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); + final Properties propsCopy = PropertyUtils.copyProperties(connectConfig.getProps()); HostSpec connectionHostSpec = hostSpec; if (PropertyDefinition.ENABLE_GREEN_NODE_REPLACEMENT.getBoolean(propsCopy) @@ -267,12 +267,12 @@ public Connection connect(@NonNull ConnectionInfo connectionInfo, @NonNull HostS } final HostSpec finalHostSpec = connectionHostSpec; - connectionInfo.getDbDialect().prepareConnectProperties( - propsCopy, connectionInfo.getProtocol(), finalHostSpec); + connectConfig.getDbDialect().prepareConnectProperties( + propsCopy, connectConfig.getProtocol(), finalHostSpec); final HikariDataSource ds = (HikariDataSource) HikariPoolsHolder.databasePools.computeIfAbsent( Pair.create(hostSpec.getUrl(), getPoolKey(finalHostSpec, propsCopy)), - (lambdaPoolKey) -> createHikariDataSource(connectionInfo, finalHostSpec, propsCopy), + (lambdaPoolKey) -> createHikariDataSource(connectConfig, finalHostSpec, propsCopy), poolExpirationCheckNanos ); @@ -306,21 +306,21 @@ public void releaseResources() { * HikariConfig passed to this method should be created via a * {@link HikariPoolConfigurator}, which allows the user to specify any * additional configuration properties. - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param hostSpec the host details used to form the connection * @param connectionProps the connection properties */ protected void configurePool( final HikariConfig config, - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final Properties connectionProps) { final Properties copy = PropertyUtils.copyProperties(connectionProps); ConnectInfo connectInfo; try { - connectInfo = connectionInfo.getDriverDialect().prepareConnectInfo( - connectionInfo.getProtocol(), hostSpec, copy); + connectInfo = connectConfig.getDriverDialect().prepareConnectInfo( + connectConfig.getProtocol(), hostSpec, copy); } catch (SQLException ex) { throw new RuntimeException(ex); } @@ -406,12 +406,12 @@ public void logConnections() { } HikariDataSource createHikariDataSource( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final Properties props) { HikariConfig config = poolConfigurator.configurePool(hostSpec, props); - configurePool(config, connectionInfo, hostSpec, props); + configurePool(config, connectConfig, hostSpec, props); return new HikariDataSource(config); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java index 6ff9e9591..ebf26efc4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PartialPluginService.java @@ -48,7 +48,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.CacheMap; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -66,7 +66,7 @@ public class PartialPluginService implements PluginService, CanReleaseResources, protected static final CacheMap hostAvailabilityExpiringCache = new CacheMap<>(); protected final FullServicesContainer servicesContainer; - protected final ConnectionInfo connectionInfo; + protected final ConnectConfig connectConfig; protected final ConnectionPluginManager pluginManager; protected volatile HostListProvider hostListProvider; protected List allHosts = new ArrayList<>(); @@ -79,18 +79,18 @@ public class PartialPluginService implements PluginService, CanReleaseResources, protected final ConnectionProviderManager connectionProviderManager; public PartialPluginService( - @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionInfo connectionInfo) { + @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectConfig connectConfig) { this( servicesContainer, new ExceptionManager(), - connectionInfo, + connectConfig, null); } public PartialPluginService( @NonNull final FullServicesContainer servicesContainer, @NonNull final ExceptionManager exceptionManager, - @NonNull final ConnectionInfo connectionInfo, + @NonNull final ConnectConfig connectConfig, @Nullable final ConfigurationProfile configurationProfile) { this.servicesContainer = servicesContainer; this.servicesContainer.setHostListProviderService(this); @@ -98,7 +98,7 @@ public PartialPluginService( this.servicesContainer.setPluginManagerService(this); this.pluginManager = servicesContainer.getConnectionPluginManager(); - this.connectionInfo = connectionInfo; + this.connectConfig = connectConfig; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; @@ -110,8 +110,8 @@ public PartialPluginService( ? this.configurationProfile.getExceptionHandler() : null; - HostListProviderSupplier supplier = this.connectionInfo.getDbDialect().getHostListProvider(); - this.hostListProvider = supplier.getProvider(this.connectionInfo, this.servicesContainer); + HostListProviderSupplier supplier = this.connectConfig.getDbDialect().getHostListProvider(); + this.hostListProvider = supplier.getProvider(this.connectConfig, this.servicesContainer); } @Override @@ -163,13 +163,13 @@ public HostSpec getInitialConnectionHostSpec() { } @Override - public ConnectionInfo getConnectionInfo() { - return this.connectionInfo; + public ConnectConfig getConnectConfig() { + return this.connectConfig; } @Override public String getOriginalUrl() { - return this.connectionInfo.getInitialConnectionString(); + return this.connectConfig.getInitialConnectionString(); } @Override @@ -214,13 +214,13 @@ public ConnectionProvider getDefaultConnectionProvider() { public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = - this.connectionProviderManager.getConnectionProvider(this.connectionInfo, host); + this.connectionProviderManager.getConnectionProvider(this.connectConfig, host); return (connectionProvider instanceof PooledConnectionProvider); } @Override public String getDriverProtocol() { - return this.connectionInfo.getProtocol(); + return this.connectConfig.getProtocol(); } @Override @@ -504,7 +504,7 @@ public Connection forceConnect( final Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { - return this.pluginManager.forceConnect(this.connectionInfo, hostSpec, true, pluginToSkip); + return this.pluginManager.forceConnect(this.connectConfig, hostSpec, true, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -524,7 +524,7 @@ public void releaseResources() { @Override public boolean isNetworkException(Throwable throwable) { - return this.isNetworkException(throwable, this.connectionInfo.getDriverDialect()); + return this.isNetworkException(throwable, this.connectConfig.getDriverDialect()); } @Override @@ -534,7 +534,7 @@ public boolean isNetworkException(final Throwable throwable, @Nullable TargetDri } return this.exceptionManager.isNetworkException( - this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); + this.connectConfig.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -543,12 +543,12 @@ public boolean isNetworkException(final String sqlState) { return this.exceptionHandler.isNetworkException(sqlState); } - return this.exceptionManager.isNetworkException(this.connectionInfo.getDbDialect(), sqlState); + return this.exceptionManager.isNetworkException(this.connectConfig.getDbDialect(), sqlState); } @Override public boolean isLoginException(Throwable throwable) { - return this.isLoginException(throwable, this.connectionInfo.getDriverDialect()); + return this.isLoginException(throwable, this.connectConfig.getDriverDialect()); } @Override @@ -558,7 +558,7 @@ public boolean isLoginException(final Throwable throwable, @Nullable TargetDrive } return this.exceptionManager.isLoginException( - this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); + this.connectConfig.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -566,17 +566,17 @@ public boolean isLoginException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(sqlState); } - return this.exceptionManager.isLoginException(this.connectionInfo.getDbDialect(), sqlState); + return this.exceptionManager.isLoginException(this.connectConfig.getDbDialect(), sqlState); } @Override public Dialect getDialect() { - return this.connectionInfo.getDbDialect(); + return this.connectConfig.getDbDialect(); } @Override public TargetDriverDialect getTargetDriverDialect() { - return this.connectionInfo.getDriverDialect(); + return this.connectConfig.getDriverDialect(); } @Override @@ -624,12 +624,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionInfo.getProps())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectConfig.getProps())); } @Override public Properties getProperties() { - return this.connectionInfo.getProps(); + return this.connectConfig.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index 13e60c794..f84bb46d5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -29,7 +29,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryFactory; /** @@ -83,11 +83,11 @@ EnumSet setCurrentConnection( HostSpec getInitialConnectionHostSpec(); /** - * Get the {@link ConnectionInfo} for the current original connection. + * Get the {@link ConnectConfig} for the current original connection. * - * @return the {@link ConnectionInfo} for the current original connection. + * @return the {@link ConnectConfig} for the current original connection. */ - ConnectionInfo getConnectionInfo(); + ConnectConfig getConnectConfig(); String getOriginalUrl(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 3a1779b2a..3fe66f6ff 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -53,7 +53,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.CacheMap; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -70,7 +70,7 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, protected static final long DEFAULT_STATUS_CACHE_EXPIRE_NANO = TimeUnit.MINUTES.toNanos(60); protected final ConnectionPluginManager pluginManager; - protected final ConnectionInfo connectionInfo; + protected final ConnectConfig connectConfig; protected volatile HostListProvider hostListProvider; protected List allHosts = new ArrayList<>(); protected Connection currentConnection; @@ -88,19 +88,19 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, protected final ReentrantLock connectionSwitchLock = new ReentrantLock(); public PluginServiceImpl( - @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectionInfo connectionInfo) + @NonNull final FullServicesContainer servicesContainer, @NonNull final ConnectConfig connectConfig) throws SQLException { - this(servicesContainer, new ExceptionManager(), connectionInfo, null, null, null); + this(servicesContainer, new ExceptionManager(), connectConfig, null, null, null); } public PluginServiceImpl( @NonNull final FullServicesContainer servicesContainer, - @NonNull final ConnectionInfo connectionInfo, + @NonNull final ConnectConfig connectConfig, @Nullable final ConfigurationProfile configurationProfile) throws SQLException { this( servicesContainer, new ExceptionManager(), - connectionInfo, + connectConfig, null, configurationProfile, null); @@ -109,13 +109,13 @@ public PluginServiceImpl( public PluginServiceImpl( @NonNull final FullServicesContainer servicesContainer, @NonNull final ExceptionManager exceptionManager, - @NonNull final ConnectionInfo connectionInfo, + @NonNull final ConnectConfig connectConfig, @Nullable final DialectProvider dialectProvider, @Nullable final ConfigurationProfile configurationProfile, @Nullable final SessionStateService sessionStateService) throws SQLException { this.servicesContainer = servicesContainer; this.pluginManager = servicesContainer.getConnectionPluginManager(); - this.connectionInfo = connectionInfo; + this.connectConfig = connectConfig; this.configurationProfile = configurationProfile; this.exceptionManager = exceptionManager; this.dialectProvider = dialectProvider != null ? dialectProvider : new DialectManager(this); @@ -125,7 +125,7 @@ public PluginServiceImpl( this.sessionStateService = sessionStateService != null ? sessionStateService - : new SessionStateServiceImpl(this, this.connectionInfo.getProps()); + : new SessionStateServiceImpl(this, this.connectConfig.getProps()); this.exceptionHandler = this.configurationProfile != null && this.configurationProfile.getExceptionHandler() != null ? this.configurationProfile.getExceptionHandler() @@ -133,8 +133,8 @@ public PluginServiceImpl( Dialect dialect = this.configurationProfile != null && this.configurationProfile.getDialect() != null ? this.configurationProfile.getDialect() - : this.dialectProvider.getDialect(this.connectionInfo); - this.connectionInfo.setDbDialect(dialect); + : this.dialectProvider.getDialect(this.connectConfig); + this.connectConfig.setDbDialect(dialect); } @Override @@ -185,13 +185,13 @@ public HostSpec getInitialConnectionHostSpec() { } @Override - public ConnectionInfo getConnectionInfo() { - return this.connectionInfo; + public ConnectConfig getConnectConfig() { + return this.connectConfig; } @Override public String getOriginalUrl() { - return this.connectionInfo.getInitialConnectionString(); + return this.connectConfig.getInitialConnectionString(); } @Override @@ -233,13 +233,13 @@ public ConnectionProvider getDefaultConnectionProvider() { public boolean isPooledConnectionProvider(HostSpec host, Properties props) { final ConnectionProvider connectionProvider = - this.connectionProviderManager.getConnectionProvider(this.connectionInfo, host); + this.connectionProviderManager.getConnectionProvider(this.connectConfig, host); return (connectionProvider instanceof PooledConnectionProvider); } @Override public String getDriverProtocol() { - return this.getConnectionInfo().getProtocol(); + return this.getConnectConfig().getProtocol(); } @Override @@ -291,7 +291,7 @@ public EnumSet setCurrentConnection( this.setInTransaction(false); if (isInTransaction - && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectionInfo.getProps())) { + && PropertyDefinition.ROLLBACK_ON_SWITCH.getBoolean(this.connectConfig.getProps())) { try { oldConnection.rollback(); } catch (final SQLException e) { @@ -592,7 +592,7 @@ public Connection connect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.connect( - this.connectionInfo, hostSpec, this.currentConnection == null, pluginToSkip); + this.connectConfig, hostSpec, this.currentConnection == null, pluginToSkip); } @Override @@ -610,7 +610,7 @@ public Connection forceConnect( final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { return this.pluginManager.forceConnect( - this.connectionInfo, hostSpec, this.currentConnection == null, pluginToSkip); + this.connectConfig, hostSpec, this.currentConnection == null, pluginToSkip); } private void updateHostAvailability(final List hosts) { @@ -643,7 +643,7 @@ public void releaseResources() { @Override @Deprecated public boolean isNetworkException(Throwable throwable) { - return this.isNetworkException(throwable, this.connectionInfo.getDriverDialect()); + return this.isNetworkException(throwable, this.connectConfig.getDriverDialect()); } @Override @@ -651,7 +651,7 @@ public boolean isNetworkException(final Throwable throwable, @Nullable TargetDri if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(throwable, targetDriverDialect); } - return this.exceptionManager.isNetworkException(this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); + return this.exceptionManager.isNetworkException(this.connectConfig.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -659,13 +659,13 @@ public boolean isNetworkException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isNetworkException(sqlState); } - return this.exceptionManager.isNetworkException(this.connectionInfo.getDbDialect(), sqlState); + return this.exceptionManager.isNetworkException(this.connectConfig.getDbDialect(), sqlState); } @Override @Deprecated public boolean isLoginException(Throwable throwable) { - return this.isLoginException(throwable, this.connectionInfo.getDriverDialect()); + return this.isLoginException(throwable, this.connectConfig.getDriverDialect()); } @Override @@ -673,7 +673,7 @@ public boolean isLoginException(final Throwable throwable, @Nullable TargetDrive if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(throwable, targetDriverDialect); } - return this.exceptionManager.isLoginException(this.connectionInfo.getDbDialect(), throwable, targetDriverDialect); + return this.exceptionManager.isLoginException(this.connectConfig.getDbDialect(), throwable, targetDriverDialect); } @Override @@ -681,30 +681,30 @@ public boolean isLoginException(final String sqlState) { if (this.exceptionHandler != null) { return this.exceptionHandler.isLoginException(sqlState); } - return this.exceptionManager.isLoginException(this.connectionInfo.getDbDialect(), sqlState); + return this.exceptionManager.isLoginException(this.connectConfig.getDbDialect(), sqlState); } @Override public Dialect getDialect() { - return this.connectionInfo.getDbDialect(); + return this.connectConfig.getDbDialect(); } @Override public TargetDriverDialect getTargetDriverDialect() { - return this.connectionInfo.getDriverDialect(); + return this.connectConfig.getDriverDialect(); } public void updateDialect(final @NonNull Connection connection) throws SQLException { - final Dialect originalDialect = this.connectionInfo.getDbDialect(); + final Dialect originalDialect = this.connectConfig.getDbDialect(); Dialect currentDialect = this.dialectProvider.getDialect( - this.connectionInfo.getInitialConnectionString(), this.initialConnectionHostSpec, connection); + this.connectConfig.getInitialConnectionString(), this.initialConnectionHostSpec, connection); if (originalDialect == currentDialect) { return; } - this.connectionInfo.setDbDialect(currentDialect); - final HostListProviderSupplier supplier = this.connectionInfo.getDbDialect().getHostListProvider(); - this.setHostListProvider(supplier.getProvider(this.connectionInfo, this.servicesContainer)); + this.connectConfig.setDbDialect(currentDialect); + final HostListProviderSupplier supplier = this.connectConfig.getDbDialect().getHostListProvider(); + this.setHostListProvider(supplier.getProvider(this.connectConfig, this.servicesContainer)); this.refreshHostList(connection); } @@ -747,12 +747,12 @@ public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLExce @Override public HostSpecBuilder getHostSpecBuilder() { - return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectionInfo.getProps())); + return new HostSpecBuilder(new HostAvailabilityStrategyFactory().create(this.connectConfig.getProps())); } @Override public Properties getProperties() { - return this.connectionInfo.getProps(); + return this.connectConfig.getProps(); } public TelemetryFactory getTelemetryFactory() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java index 5d3675674..247981839 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java @@ -89,11 +89,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> { + return (connectConfig, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsHostListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -102,7 +102,7 @@ public HostListProviderSupplier getHostListProvider() { } return new AuroraHostListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java index 34d589c2a..4499db901 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java @@ -135,11 +135,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> { + return (connectConfig, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsHostListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -148,7 +148,7 @@ public HostListProviderSupplier getHostListProvider() { } return new AuroraHostListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index 075ca0687..ec63ae39d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -35,7 +35,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.CacheMap; public class DialectManager implements DialectProvider { @@ -119,7 +119,7 @@ public static void resetEndpointCache() { } @Override - public Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws SQLException { + public Dialect getDialect(final @NonNull ConnectConfig connectConfig) throws SQLException { this.canUpdate = false; this.dialect = null; @@ -131,10 +131,10 @@ public Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws S return this.dialect; } - final String userDialectSetting = DIALECT.getString(connectionInfo.getProps()); + final String userDialectSetting = DIALECT.getString(connectConfig.getProps()); final String dialectCode = !StringUtils.isNullOrEmpty(userDialectSetting) ? userDialectSetting - : knownEndpointDialects.get(connectionInfo.getInitialConnectionString()); + : knownEndpointDialects.get(connectConfig.getInitialConnectionString()); if (!StringUtils.isNullOrEmpty(dialectCode)) { final Dialect userDialect = knownDialectsByCode.get(dialectCode); @@ -149,18 +149,18 @@ public Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws S } } - if (StringUtils.isNullOrEmpty(connectionInfo.getProtocol())) { + if (StringUtils.isNullOrEmpty(connectConfig.getProtocol())) { throw new IllegalArgumentException("protocol"); } - String connectionString = connectionInfo.getInitialConnectionString(); + String connectionString = connectConfig.getInitialConnectionString(); final List hosts = this.connectionUrlParser.getHostsFromConnectionUrl( - connectionInfo.getInitialConnectionString(), true, pluginService::getHostSpecBuilder); + connectConfig.getInitialConnectionString(), true, pluginService::getHostSpecBuilder); if (!Utils.isNullOrEmpty(hosts)) { connectionString = hosts.get(0).getHost(); } - if (connectionInfo.getProtocol().contains("mysql")) { + if (connectConfig.getProtocol().contains("mysql")) { RdsUrlType type = this.rdsHelper.identifyRdsType(connectionString); if (type.isRdsCluster()) { this.canUpdate = true; @@ -182,7 +182,7 @@ public Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws S return this.dialect; } - if (connectionInfo.getProtocol().contains("postgresql")) { + if (connectConfig.getProtocol().contains("postgresql")) { RdsUrlType type = this.rdsHelper.identifyRdsType(connectionString); if (RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP.equals(type)) { this.canUpdate = false; @@ -210,7 +210,7 @@ public Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws S return this.dialect; } - if (connectionInfo.getProtocol().contains("mariadb")) { + if (connectConfig.getProtocol().contains("mariadb")) { this.canUpdate = true; this.dialectCode = DialectCodes.MARIADB; this.dialect = knownDialectsByCode.get(DialectCodes.MARIADB); diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java index 384aeb22d..910124235 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java @@ -20,10 +20,10 @@ import java.sql.SQLException; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public interface DialectProvider { - Dialect getDialect(final @NonNull ConnectionInfo connectionInfo) throws SQLException; + Dialect getDialect(final @NonNull ConnectConfig connectConfig) throws SQLException; Dialect getDialect( final @NonNull String originalUrl, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java index ad6ddbc19..d9ce81918 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/HostListProviderSupplier.java @@ -19,11 +19,11 @@ import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.HostListProvider; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; @FunctionalInterface public interface HostListProviderSupplier { @NonNull HostListProvider getProvider( - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull FullServicesContainer servicesContainer); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java index 4bd593c34..6a5a6e0f2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java @@ -104,8 +104,8 @@ public List getDialectUpdateCandidates() { } public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> - new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); + return (connectConfig, servicesContainer) -> + new ConnectionStringHostListProvider(connectConfig, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java index db5e7556c..bf549dcd8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java @@ -105,8 +105,8 @@ public List getDialectUpdateCandidates() { } public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> - new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); + return (connectConfig, servicesContainer) -> + new ConnectionStringHostListProvider(connectConfig, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java index ebd48c64d..a02094553 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java @@ -106,8 +106,8 @@ public List getDialectUpdateCandidates() { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> - new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); + return (connectConfig, servicesContainer) -> + new ConnectionStringHostListProvider(connectConfig, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java index 7ac9bcc10..e8eed7f33 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java @@ -94,11 +94,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> { + return (connectConfig, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsMultiAzHostListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -108,7 +108,7 @@ public HostListProviderSupplier getHostListProvider() { } else { return new RdsMultiAzDbClusterListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java index fcf94826e..fbdeced04 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java @@ -80,11 +80,11 @@ public boolean isDialect(final Connection connection) { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> { + return (connectConfig, servicesContainer) -> { final PluginService pluginService = servicesContainer.getPluginService(); if (pluginService.isPluginInUse(FailoverConnectionPlugin.class)) { return new MonitoringRdsMultiAzHostListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, @@ -95,7 +95,7 @@ public HostListProviderSupplier getHostListProvider() { } else { return new RdsMultiAzDbClusterListProvider( - connectionInfo, + connectConfig, servicesContainer, TOPOLOGY_QUERY, NODE_ID_QUERY, diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java index 7350a52ca..9fdcb3c79 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java @@ -81,8 +81,8 @@ public List getDialectUpdateCandidates() { @Override public HostListProviderSupplier getHostListProvider() { - return (connectionInfo, servicesContainer) -> - new ConnectionStringHostListProvider(connectionInfo, servicesContainer.getHostListProviderService()); + return (connectConfig, servicesContainer) -> + new ConnectionStringHostListProvider(connectConfig, servicesContainer.getHostListProviderService()); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java index 12cb429e1..3c76d93ae 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java @@ -19,7 +19,7 @@ import java.util.logging.Logger; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class AuroraHostListProvider extends RdsHostListProvider { @@ -27,13 +27,13 @@ public class AuroraHostListProvider extends RdsHostListProvider { static final Logger LOGGER = Logger.getLogger(AuroraHostListProvider.class.getName()); public AuroraHostListProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery) { super( - connectionInfo, + connectConfig, servicesContainer, topologyQuery, nodeIdQuery, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index 09239d6d2..7cdfe78c6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -29,7 +29,7 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.util.ConnectionUrlParser; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class ConnectionStringHostListProvider implements StaticHostListProvider { @@ -39,7 +39,7 @@ public class ConnectionStringHostListProvider implements StaticHostListProvider private boolean isInitialized = false; private final boolean isSingleWriterConnectionString; private final ConnectionUrlParser connectionUrlParser; - private final ConnectionInfo connectionInfo; + private final ConnectConfig connectConfig; private final HostListProviderService hostListProviderService; public static final AwsWrapperProperty SINGLE_WRITER_CONNECTION_STRING = @@ -50,17 +50,17 @@ public class ConnectionStringHostListProvider implements StaticHostListProvider + "cluster has only one writer. The writer must be the first host in the connection string"); public ConnectionStringHostListProvider( - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostListProviderService hostListProviderService) { - this(connectionInfo, hostListProviderService, new ConnectionUrlParser()); + this(connectConfig, hostListProviderService, new ConnectionUrlParser()); } ConnectionStringHostListProvider( - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostListProviderService hostListProviderService, final @NonNull ConnectionUrlParser connectionUrlParser) { - this.connectionInfo = connectionInfo; - this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectionInfo.getProps()); + this.connectConfig = connectConfig; + this.isSingleWriterConnectionString = SINGLE_WRITER_CONNECTION_STRING.getBoolean(connectConfig.getProps()); this.connectionUrlParser = connectionUrlParser; this.hostListProviderService = hostListProviderService; } @@ -71,12 +71,12 @@ private void init() throws SQLException { } this.hostList.addAll( this.connectionUrlParser.getHostsFromConnectionUrl( - this.connectionInfo.getInitialConnectionString(), + this.connectConfig.getInitialConnectionString(), this.isSingleWriterConnectionString, () -> this.hostListProviderService.getHostSpecBuilder())); if (this.hostList.isEmpty()) { throw new SQLException(Messages.get("ConnectionStringHostListProvider.parsedListEmpty", - new Object[] {this.connectionInfo.getInitialConnectionString()})); + new Object[] {this.connectConfig.getInitialConnectionString()})); } this.hostListProviderService.setInitialConnectionHostSpec(this.hostList.get(0)); this.isInitialized = true; diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java index 97800cbe1..77d5252e7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java @@ -55,7 +55,7 @@ import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SynchronousExecutor; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.CacheMap; public class RdsHostListProvider implements DynamicHostListProvider { @@ -95,7 +95,7 @@ public class RdsHostListProvider implements DynamicHostListProvider { protected final FullServicesContainer servicesContainer; protected final HostListProviderService hostListProviderService; - protected final ConnectionInfo connectionInfo; + protected final ConnectConfig connectConfig; protected final String topologyQuery; protected final String nodeIdQuery; protected final String isReaderQuery; @@ -122,12 +122,12 @@ public class RdsHostListProvider implements DynamicHostListProvider { } public RdsHostListProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery) { - this.connectionInfo = connectionInfo; + this.connectConfig = connectConfig; this.servicesContainer = servicesContainer; this.hostListProviderService = servicesContainer.getHostListProviderService(); this.topologyQuery = topologyQuery; @@ -148,18 +148,18 @@ protected void init() throws SQLException { // initial topology is based on connection string this.initialHostList = - connectionUrlParser.getHostsFromConnectionUrl(this.connectionInfo.getInitialConnectionString(), false, + connectionUrlParser.getHostsFromConnectionUrl(this.connectConfig.getInitialConnectionString(), false, this.hostListProviderService::getHostSpecBuilder); if (this.initialHostList == null || this.initialHostList.isEmpty()) { throw new SQLException(Messages.get("RdsHostListProvider.parsedListEmpty", - new Object[] {this.connectionInfo.getInitialConnectionString()})); + new Object[] {this.connectConfig.getInitialConnectionString()})); } this.initialHostSpec = this.initialHostList.get(0); this.hostListProviderService.setInitialConnectionHostSpec(this.initialHostSpec); this.clusterId = UUID.randomUUID().toString(); this.isPrimaryClusterId = false; - Properties props = this.connectionInfo.getProps(); + Properties props = this.connectConfig.getProps(); this.refreshRateNano = TimeUnit.MILLISECONDS.toNanos(CLUSTER_TOPOLOGY_REFRESH_RATE_MS.getInteger(props)); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java index a32748ef1..761344458 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProvider.java @@ -32,7 +32,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class RdsMultiAzDbClusterListProvider extends RdsHostListProvider { private final String fetchWriterNodeQuery; @@ -40,7 +40,7 @@ public class RdsMultiAzDbClusterListProvider extends RdsHostListProvider { static final Logger LOGGER = Logger.getLogger(RdsMultiAzDbClusterListProvider.class.getName()); public RdsMultiAzDbClusterListProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, @@ -49,7 +49,7 @@ public RdsMultiAzDbClusterListProvider( final String fetchWriterNodeQueryHeader ) { super( - connectionInfo, + connectConfig, servicesContainer, topologyQuery, nodeIdQuery, diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java index 4ce035c6d..72526e112 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/ClusterTopologyMonitorImpl.java @@ -510,7 +510,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.servicesContainer.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.servicesContainer.getPluginService().getConnectionInfo() + this.servicesContainer.getPluginService().getConnectConfig() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java index fea674ae1..e60553d26 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsHostListProvider.java @@ -31,7 +31,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import software.amazon.jdbc.hostlistprovider.Topology; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; @@ -56,18 +56,18 @@ public class MonitoringRdsHostListProvider extends RdsHostListProvider protected final String writerTopologyQuery; public MonitoringRdsHostListProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, final String isReaderQuery, final String writerTopologyQuery) { - super(connectionInfo, servicesContainer, topologyQuery, nodeIdQuery, isReaderQuery); + super(connectConfig, servicesContainer, topologyQuery, nodeIdQuery, isReaderQuery); this.servicesContainer = servicesContainer; this.pluginService = servicesContainer.getPluginService(); this.writerTopologyQuery = writerTopologyQuery; this.highRefreshRateNano = TimeUnit.MILLISECONDS.toNanos( - CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.getLong(this.connectionInfo.getProps())); + CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.getLong(this.connectConfig.getProps())); } public static void clearCache() { @@ -86,12 +86,12 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.connectionInfo, + this.connectConfig, (servicesContainer) -> new ClusterTopologyMonitorImpl( this.servicesContainer, this.clusterId, this.initialHostSpec, - this.connectionInfo.getProps(), + this.connectConfig.getProps(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, @@ -127,7 +127,7 @@ protected void clusterIdChanged(final String oldClusterId) throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.connectionInfo, + this.connectConfig, (servicesContainer) -> existingMonitor); assert monitorService.get(ClusterTopologyMonitorImpl.class, this.clusterId) == existingMonitor; existingMonitor.setClusterId(this.clusterId); diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java index 30c2c77da..fff4234a0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/monitoring/MonitoringRdsMultiAzHostListProvider.java @@ -19,7 +19,7 @@ import java.sql.SQLException; import java.util.logging.Logger; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class MonitoringRdsMultiAzHostListProvider extends MonitoringRdsHostListProvider { @@ -29,7 +29,7 @@ public class MonitoringRdsMultiAzHostListProvider extends MonitoringRdsHostListP protected final String fetchWriterNodeColumnName; public MonitoringRdsMultiAzHostListProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final FullServicesContainer servicesContainer, final String topologyQuery, final String nodeIdQuery, @@ -37,7 +37,7 @@ public MonitoringRdsMultiAzHostListProvider( final String fetchWriterNodeQuery, final String fetchWriterNodeColumnName) { super( - connectionInfo, + connectConfig, servicesContainer, topologyQuery, nodeIdQuery, @@ -55,12 +55,12 @@ protected ClusterTopologyMonitor initMonitor() throws SQLException { this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.servicesContainer.getDefaultConnectionProvider(), - this.connectionInfo, + this.connectConfig, (servicesContainer) -> new MultiAzClusterTopologyMonitorImpl( servicesContainer, this.clusterId, this.initialHostSpec, - this.connectionInfo.getProps(), + this.connectConfig.getProps(), this.clusterInstanceTemplate, this.refreshRateNano, this.highRefreshRateNano, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java index 2a0561898..14c0b8e83 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AbstractConnectionPlugin.java @@ -29,7 +29,7 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.OldConnectionSuggestedAction; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public abstract class AbstractConnectionPlugin implements ConnectionPlugin { @@ -50,7 +50,7 @@ public T execute( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) @@ -60,7 +60,7 @@ public Connection connect( @Override public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) @@ -87,7 +87,7 @@ public HostSpec getHostSpecByStrategy(final List hosts, final HostRole @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 43796f425..138020ae0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { @@ -84,7 +84,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 388d993ac..6093c19f6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -39,7 +39,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlugin { @@ -120,7 +120,7 @@ public Set getSubscribedMethods() { @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -133,13 +133,13 @@ public void initHostProvider( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost()); - final Properties props = connectionInfo.getProps(); + final Properties props = connectConfig.getProps(); if (type == RdsUrlType.RDS_WRITER_CLUSTER || isInitialConnection && this.verifyOpenedConnectionType == VerifyOpenedConnectionType.WRITER) { Connection writerCandidateConn = this.getVerifiedWriterConnection(props, isInitialConnection, connectFunc); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java index ec5e24688..0fc4d5058 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java @@ -50,7 +50,7 @@ import software.amazon.jdbc.util.Pair; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -185,12 +185,12 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionInfo.getProps(), connectFunc); + return connectInternal(hostSpec, connectConfig.getProps(), connectFunc); } private Connection connectInternal(HostSpec hostSpec, Properties props, @@ -226,12 +226,12 @@ private Connection connectInternal(HostSpec hostSpec, Properties props, @Override public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionInfo.getProps(), forceConnectFunc); + return connectInternal(hostSpec, connectConfig.getProps(), forceConnectFunc); } /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java index 313aeb666..ccb56d843 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java @@ -26,7 +26,7 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class ConnectTimeConnectionPlugin extends AbstractConnectionPlugin { @@ -44,7 +44,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { @@ -63,7 +63,7 @@ public Connection connect( @Override public Connection forceConnect( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java index ed51292a3..b8706c0fd 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DefaultConnectionPlugin.java @@ -45,7 +45,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlMethodAnalyzer; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; @@ -154,20 +154,20 @@ public T execute( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - ConnectionProvider connProvider = this.connProviderManager.getConnectionProvider(connectionInfo, hostSpec); + ConnectionProvider connProvider = this.connProviderManager.getConnectionProvider(connectConfig, hostSpec); // It's guaranteed that this plugin is always the last in plugin chain so connectFunc can be // ignored. - return connectInternal(connectionInfo, hostSpec, connProvider, isInitialConnection); + return connectInternal(connectConfig, hostSpec, connProvider, isInitialConnection); } private Connection connectInternal( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final ConnectionProvider connProvider, final boolean isInitialConnection) throws SQLException { @@ -177,14 +177,14 @@ private Connection connectInternal( Connection conn; try { - conn = connProvider.connect(connectionInfo, hostSpec); + conn = connProvider.connect(connectConfig, hostSpec); } finally { if (telemetryContext != null) { telemetryContext.closeContext(); } } - this.connProviderManager.initConnection(conn, connectionInfo, hostSpec); + this.connProviderManager.initConnection(conn, connectConfig, hostSpec); this.pluginService.setAvailability(hostSpec.asAliases(), HostAvailability.AVAILABLE); if (isInitialConnection) { @@ -196,7 +196,7 @@ private Connection connectInternal( @Override public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) @@ -204,7 +204,7 @@ public Connection forceConnect( // It's guaranteed that this plugin is always the last in plugin chain so forceConnectFunc can be // ignored. - return connectInternal(connectionInfo, hostSpec, this.defaultConnProvider, isInitialConnection); + return connectInternal(connectConfig, hostSpec, this.defaultConnProvider, isInitialConnection); } @Override @@ -242,7 +242,7 @@ public HostSpec getHostSpecByStrategy(final List hosts, final HostRole @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java index c63e3740c..2e92202ad 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenConnectionPlugin.java @@ -42,7 +42,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -127,7 +127,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java index 69db5d091..77cfe037c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/bluegreen/BlueGreenStatusMonitor.java @@ -57,7 +57,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class BlueGreenStatusMonitor { @@ -599,7 +599,7 @@ protected void initHostListProvider() { return; } - final ConnectionInfo originalContext = this.pluginService.getConnectionInfo(); + final ConnectConfig originalContext = this.pluginService.getConnectConfig(); final Properties hostListProperties = originalContext.getProps(); // Need to instantiate a separate HostListProvider with @@ -616,7 +616,7 @@ protected void initHostListProvider() { if (connectionHostSpecCopy != null) { String hostListProviderUrl = String.format("%s%s/", originalContext.getProtocol(), connectionHostSpecCopy.getHostAndPort()); - ConnectionInfo newContext = new ConnectionInfo( + ConnectConfig newContext = new ConnectConfig( hostListProviderUrl, originalContext.getProtocol(), originalContext.getDriverDialect(), hostListProperties); this.hostListProvider = this.pluginService.getDialect().getHostListProvider().getProvider(newContext, this.servicesContainer); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java index 42ab0e189..1c77e9468 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.monitoring.MonitorErrorResponse; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -164,7 +164,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -217,7 +217,7 @@ protected CustomEndpointMonitor createMonitorIfAbsent(Properties props) throws S this.servicesContainer.getStorageService(), this.pluginService.getTelemetryFactory(), this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getConnectionInfo(), + this.pluginService.getConnectConfig(), (servicesContainer) -> new CustomEndpointMonitorImpl( servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java index ba62818f8..e92cf6db8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPlugin.java @@ -28,7 +28,7 @@ import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class DeveloperConnectionPlugin extends AbstractConnectionPlugin implements ExceptionSimulator { @@ -142,28 +142,28 @@ protected void raiseException( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - this.raiseExceptionOnConnectIfNeeded(connectionInfo, hostSpec, isInitialConnection); - return super.connect(connectionInfo, hostSpec, isInitialConnection, connectFunc); + this.raiseExceptionOnConnectIfNeeded(connectConfig, hostSpec, isInitialConnection); + return super.connect(connectConfig, hostSpec, isInitialConnection, connectFunc); } @Override public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { - this.raiseExceptionOnConnectIfNeeded(connectionInfo, hostSpec, isInitialConnection); - return super.connect(connectionInfo, hostSpec, isInitialConnection, forceConnectFunc); + this.raiseExceptionOnConnectIfNeeded(connectConfig, hostSpec, isInitialConnection); + return super.connect(connectConfig, hostSpec, isInitialConnection, forceConnectFunc); } protected void raiseExceptionOnConnectIfNeeded( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection) throws SQLException { @@ -173,7 +173,7 @@ protected void raiseExceptionOnConnectIfNeeded( } else if (ExceptionSimulatorManager.connectCallback != null) { this.raiseExceptionOnConnect( ExceptionSimulatorManager.connectCallback.getExceptionToRaise( - connectionInfo, hostSpec, isInitialConnection)); + connectConfig, hostSpec, isInitialConnection)); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java index 28f8a6fe7..ae4c187e0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/dev/ExceptionSimulatorConnectCallback.java @@ -18,11 +18,11 @@ import java.sql.SQLException; import software.amazon.jdbc.HostSpec; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public interface ExceptionSimulatorConnectCallback { SQLException getExceptionToRaise( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java index c0359c4e7..c42183ede 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * Monitor the server while the connection is executing methods for more sophisticated failure @@ -270,7 +270,7 @@ public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java index 7879a7f40..e0ba7e625 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitorServiceImpl.java @@ -156,7 +156,7 @@ protected HostMonitor getMonitor( this.serviceContainer.getStorageService(), this.telemetryFactory, this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getConnectionInfo(), + this.pluginService.getConnectConfig(), (servicesContainer) -> new HostMonitorImpl( servicesContainer, hostSpec, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java index 134234656..fe81c3927 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java @@ -41,7 +41,7 @@ import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * Monitor the server while the connection is executing methods for more sophisticated failure @@ -228,7 +228,7 @@ public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index 19c986edb..2be799bfc 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -364,7 +364,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.pluginService.getConnectionInfo() + this.pluginService.getConnectConfig() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index d171387cf..43dc6ce69 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -174,7 +174,7 @@ protected FullServicesContainer getNewServicesContainer() throws SQLException { this.servicesContainer.getMonitorService(), this.pluginService.getDefaultConnectionProvider(), this.servicesContainer.getTelemetryFactory(), - this.pluginService.getConnectionInfo() + this.pluginService.getConnectConfig() ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index d7a2b88c2..78584dded 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -53,7 +53,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -302,7 +302,7 @@ public T execute( @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -904,7 +904,7 @@ private boolean canDirectExecute(final String methodName) { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -913,7 +913,7 @@ public Connection connect( Connection conn = null; try { conn = this.staleDnsHelper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectConfig, hostSpec, connectFunc); } catch (final SQLException e) { if (!this.enableConnectFailover || !shouldExceptionTriggerConnectionSwitch(e)) { throw e; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java index 4c831043d..ec8c961a4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java @@ -53,7 +53,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -233,7 +233,7 @@ public T execute( @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -727,17 +727,17 @@ protected void initFailoverMode() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { this.initFailoverMode(); Connection conn = null; - Properties props = connectionInfo.getProps(); + Properties props = connectConfig.getProps(); if (!ENABLE_CONNECT_FAILOVER.getBoolean(props)) { return this.staleDnsHelper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectConfig, hostSpec, connectFunc); } final HostSpec hostSpecWithAvailability = this.pluginService.getHosts().stream() @@ -750,7 +750,7 @@ public Connection connect( try { conn = this.staleDnsHelper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectConfig, hostSpec, connectFunc); } catch (final SQLException e) { if (!this.shouldExceptionTriggerConnectionSwitch(e)) { throw e; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index b45e06c8b..6315282ad 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -42,7 +42,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -153,21 +153,21 @@ public FederatedAuthPlugin(final PluginService pluginService, @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionInfo.getProps(), connectFunc); + return connectInternal(hostSpec, connectConfig.getProps(), connectFunc); } @Override public Connection forceConnect( - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec, final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionInfo.getProps(), forceConnectFunc); + return connectInternal(hostSpec, connectConfig.getProps(), forceConnectFunc); } private Connection connectInternal( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index 2a80104c7..3bd2ea31c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -134,21 +134,21 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectionInfo.getProps(), connectFunc); + return connectInternal(hostSpec, connectConfig.getProps(), connectFunc); } @Override public Connection forceConnect( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostSpec hostSpec, boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, connectionInfo.getProps(), forceConnectFunc); + return connectInternal(hostSpec, connectConfig.getProps(), forceConnectFunc); } private Connection connectInternal(final HostSpec hostSpec, final Properties props, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index d314d6edf..1fa787778 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; @@ -107,18 +107,18 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(connectionInfo, hostSpec, connectFunc); + return connectInternal(connectConfig, hostSpec, connectFunc); } private Connection connectInternal( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostSpec hostSpec, JdbcCallable connectFunc) throws SQLException { - Properties props = connectionInfo.getProps(); + Properties props = connectConfig.getProps(); if (StringUtils.isNullOrEmpty(PropertyDefinition.USER.getString(props))) { throw new SQLException(PropertyDefinition.USER.name + " is null or empty."); } @@ -226,12 +226,12 @@ private Connection connectInternal( @Override public Connection forceConnect( - final @NonNull ConnectionInfo connectionInfo, + final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec, final boolean isInitialConnection, final @NonNull JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(connectionInfo, hostSpec, forceConnectFunc); + return connectInternal(connectConfig, hostSpec, forceConnectFunc); } public static void clearCache() { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java index 63bcf3094..10832a975 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java @@ -35,7 +35,7 @@ import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class LimitlessConnectionPlugin extends AbstractConnectionPlugin { @@ -104,7 +104,7 @@ public LimitlessConnectionPlugin( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -130,7 +130,7 @@ public Connection connect( final LimitlessConnectionContext context = new LimitlessConnectionContext( hostSpec, - connectionInfo.getProps(), + connectConfig.getProps(), conn, connectFunc, null, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java index 0ae98fed6..4d0ff10ba 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java @@ -323,7 +323,7 @@ public void startMonitoring(final @NonNull HostSpec hostSpec, final @NonNull Pro this.servicesContainer.getStorageService(), this.servicesContainer.getTelemetryFactory(), this.pluginService.getDefaultConnectionProvider(), - this.pluginService.getConnectionInfo(), + this.pluginService.getConnectConfig(), (servicesContainer) -> new LimitlessRouterMonitor( servicesContainer, hostSpec, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 2c3482f12..e539432a9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -44,7 +44,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implements CanReleaseResources { @@ -125,7 +125,7 @@ public Set getSubscribedMethods() { @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { @@ -135,7 +135,7 @@ public void initHostProvider( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java index e82f65214..10d676c60 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java @@ -33,7 +33,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.Utils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -61,7 +61,7 @@ public AuroraStaleDnsHelper(final PluginService pluginService) { public Connection getVerifiedConnection( final boolean isInitialConnection, final HostListProviderService hostListProviderService, - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final JdbcCallable connectFunc) throws SQLException { @@ -147,7 +147,7 @@ public Connection getVerifiedConnection( ); } - final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectionInfo.getProps()); + final Connection writerConn = this.pluginService.connect(this.writerHostSpec, connectConfig.getProps()); if (isInitialConnection) { hostListProviderService.setInitialConnectionHostSpec(this.writerHostSpec); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java index c1f622b0b..d5c2ba2d5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java @@ -32,7 +32,7 @@ import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; /** * After Aurora DB cluster fail over is completed and a cluster has elected a new writer node, the corresponding @@ -76,17 +76,17 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { return this.helper.getVerifiedConnection( - isInitialConnection, this.hostListProviderService, connectionInfo, hostSpec, connectFunc); + isInitialConnection, this.hostListProviderService, connectConfig, hostSpec, connectFunc); } @Override public void initHostProvider( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostListProviderService hostListProviderService, final JdbcCallable initHostProviderFunc) throws SQLException { this.hostListProviderService = hostListProviderService; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java index ae06497a9..7ddb3037b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.RandomHostSelector; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.CacheMap; public class FastestResponseStrategyPlugin extends AbstractConnectionPlugin { @@ -112,7 +112,7 @@ public Set getSubscribedMethods() { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java index bbd31ec1d..2514a335a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/HostResponseTimeServiceImpl.java @@ -79,7 +79,7 @@ public void setHosts(final @NonNull List hosts) { servicesContainer.getStorageService(), servicesContainer.getTelemetryFactory(), servicesContainer.getDefaultConnectionProvider(), - this.pluginService.getConnectionInfo(), + this.pluginService.getConnectConfig(), (servicesContainer) -> new NodeResponseTimeMonitor(pluginService, hostSpec, this.props, this.intervalMs)); } catch (SQLException e) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java index 1364f3c2f..a47d434b6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/ServiceUtility.java @@ -22,7 +22,7 @@ import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.PartialPluginService; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -59,20 +59,20 @@ public FullServicesContainer createServiceContainer( MonitorService monitorService, ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, - ConnectionInfo connectionInfo) throws SQLException { + ConnectConfig connectConfig) throws SQLException { FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, connectionProvider, telemetryFactory); ConnectionPluginManager pluginManager = new ConnectionPluginManager( connectionProvider, null, null, telemetryFactory); servicesContainer.setConnectionPluginManager(pluginManager); - PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, connectionInfo); + PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, connectConfig); servicesContainer.setHostListProviderService(partialPluginService); servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); - Properties propsCopy = PropertyUtils.copyProperties(connectionInfo.getProps()); + Properties propsCopy = PropertyUtils.copyProperties(connectConfig.getProps()); pluginManager.init(servicesContainer, propsCopy, partialPluginService, null); return servicesContainer; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionInfo.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectConfig.java similarity index 92% rename from wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionInfo.java rename to wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectConfig.java index 75dbc7a0f..f8c7eb439 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionInfo.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectConfig.java @@ -21,7 +21,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.ConnectionUrlParser; -public class ConnectionInfo { +public class ConnectConfig { protected static final ConnectionUrlParser connectionUrlParser = new ConnectionUrlParser(); protected final String initialConnectionString; protected final String protocol; @@ -29,11 +29,11 @@ public class ConnectionInfo { protected final Properties props; protected Dialect dbDialect; - public ConnectionInfo(String initialConnectionString, TargetDriverDialect driverDialect, Properties props) { + public ConnectConfig(String initialConnectionString, TargetDriverDialect driverDialect, Properties props) { this(initialConnectionString, connectionUrlParser.getProtocol(initialConnectionString), driverDialect, props); } - public ConnectionInfo( + public ConnectConfig( String initialConnectionString, String protocol, TargetDriverDialect driverDialect, Properties props) { this.initialConnectionString = initialConnectionString; this.protocol = protocol; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java index 6e29ba7a7..40fca65ff 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/connection/ConnectionServiceImpl.java @@ -39,7 +39,7 @@ */ @Deprecated public class ConnectionServiceImpl implements ConnectionService { - protected final ConnectionInfo connectionInfo; + protected final ConnectConfig connectConfig; protected final ConnectionPluginManager pluginManager; protected final PluginService pluginService; @@ -54,8 +54,8 @@ public ConnectionServiceImpl( MonitorService monitorService, TelemetryFactory telemetryFactory, ConnectionProvider connectionProvider, - ConnectionInfo connectionInfo) throws SQLException { - this.connectionInfo = connectionInfo; + ConnectConfig connectConfig) throws SQLException { + this.connectConfig = connectConfig; FullServicesContainer servicesContainer = new FullServicesContainerImpl(storageService, monitorService, connectionProvider, telemetryFactory); @@ -65,20 +65,20 @@ public ConnectionServiceImpl( null, telemetryFactory); servicesContainer.setConnectionPluginManager(this.pluginManager); - PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, this.connectionInfo); + PartialPluginService partialPluginService = new PartialPluginService(servicesContainer, this.connectConfig); servicesContainer.setHostListProviderService(partialPluginService); servicesContainer.setPluginService(partialPluginService); servicesContainer.setPluginManagerService(partialPluginService); this.pluginService = partialPluginService; - this.pluginManager.init(servicesContainer, this.connectionInfo.getProps(), partialPluginService, null); + this.pluginManager.init(servicesContainer, this.connectConfig.getProps(), partialPluginService, null); } @Override @Deprecated public Connection open(HostSpec hostSpec, Properties props) throws SQLException { - return this.pluginManager.forceConnect(this.connectionInfo, hostSpec, true, null); + return this.pluginManager.forceConnect(this.connectConfig, hostSpec, true, null); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java index 019a742a2..abce3f9a9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorService.java @@ -20,7 +20,7 @@ import java.util.Set; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.ConnectionProvider; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -59,7 +59,7 @@ void registerMonitorTypeIfAbsent( * @param telemetryFactory the telemetry factory for creating telemetry data. * @param defaultConnectionProvider the connection provider to use to create new connections if the monitor * requires it. - * @param connectionInfo the connection info for the original connection. + * @param connectConfig the connection info for the original connection. * @param initializer an initializer function to use to create the monitor if it does not already exist. * @param the type of the monitor. * @return the new or existing monitor. @@ -71,7 +71,7 @@ T runIfAbsent( StorageService storageService, TelemetryFactory telemetryFactory, ConnectionProvider defaultConnectionProvider, - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, MonitorInitializer initializer) throws SQLException; /** diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java index d00ab68d1..265c7b5a9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/monitoring/MonitorServiceImpl.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.ServiceUtility; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.events.DataAccessEvent; import software.amazon.jdbc.util.events.Event; import software.amazon.jdbc.util.events.EventPublisher; @@ -179,7 +179,7 @@ public T runIfAbsent( StorageService storageService, TelemetryFactory telemetryFactory, ConnectionProvider defaultConnectionProvider, - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, MonitorInitializer initializer) throws SQLException { CacheContainer cacheContainer = monitorCaches.get(monitorClass); if (cacheContainer == null) { @@ -201,7 +201,7 @@ public T runIfAbsent( storageService, defaultConnectionProvider, telemetryFactory, - connectionInfo); + connectConfig); final MonitorItem monitorItemInner = new MonitorItem(() -> initializer.createMonitor(servicesContainer)); monitorItemInner.getMonitor().start(); return monitorItemInner; @@ -228,13 +228,13 @@ protected FullServicesContainer getNewServicesContainer( StorageService storageService, ConnectionProvider connectionProvider, TelemetryFactory telemetryFactory, - ConnectionInfo connectionInfo) throws SQLException { + ConnectConfig connectConfig) throws SQLException { return ServiceUtility.getInstance().createServiceContainer( storageService, this, connectionProvider, telemetryFactory, - connectionInfo + connectConfig ); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index a938728f8..40d7776f9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -57,7 +57,7 @@ import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -66,7 +66,7 @@ public class ConnectionWrapper implements Connection, CanReleaseResources { private static final Logger LOGGER = Logger.getLogger(ConnectionWrapper.class.getName()); - protected ConnectionInfo connectionInfo; + protected ConnectConfig connectConfig; protected ConnectionPluginManager pluginManager; protected TelemetryFactory telemetryFactory; protected PluginService pluginService; @@ -91,7 +91,7 @@ public ConnectionWrapper( throw new IllegalArgumentException("url"); } - this.connectionInfo = new ConnectionInfo(url, driverDialect, props); + this.connectConfig = new ConnectConfig(url, driverDialect, props); this.configurationProfile = configurationProfile; final ConnectionPluginManager pluginManager = @@ -103,7 +103,7 @@ public ConnectionWrapper( servicesContainer.setConnectionPluginManager(pluginManager); final PluginServiceImpl pluginService = - new PluginServiceImpl(servicesContainer, this.connectionInfo, this.configurationProfile); + new PluginServiceImpl(servicesContainer, this.connectConfig, this.configurationProfile); servicesContainer.setHostListProviderService(pluginService); servicesContainer.setPluginService(pluginService); servicesContainer.setPluginManagerService(pluginService); @@ -159,15 +159,15 @@ protected void init(final Properties props, final FullServicesContainer services final HostListProviderSupplier supplier = this.pluginService.getDialect().getHostListProvider(); if (supplier != null) { - final HostListProvider provider = supplier.getProvider(this.connectionInfo, servicesContainer); + final HostListProvider provider = supplier.getProvider(this.connectConfig, servicesContainer); hostListProviderService.setHostListProvider(provider); } - this.pluginManager.initHostProvider(this.connectionInfo, this.hostListProviderService); + this.pluginManager.initHostProvider(this.connectConfig, this.hostListProviderService); this.pluginService.refreshHostList(); if (this.pluginService.getCurrentConnection() == null) { final Connection conn = this.pluginManager.connect( - this.connectionInfo, this.pluginService.getInitialConnectionHostSpec(), true, null); + this.connectConfig, this.pluginService.getInitialConnectionHostSpec(), true, null); if (conn == null) { throw new SQLException(Messages.get("ConnectionWrapper.connectionNotOpen"), SqlState.UNKNOWN_STATE.getState()); } diff --git a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java index c78ce3f16..f695ff43f 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java +++ b/wrapper/src/test/java/integration/container/aurora/TestAuroraHostListProvider.java @@ -18,13 +18,13 @@ import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class TestAuroraHostListProvider extends AuroraHostListProvider { public TestAuroraHostListProvider( - ConnectionInfo connectionInfo, FullServicesContainer servicesContainer) { - super(connectionInfo, servicesContainer, "", "", ""); + ConnectConfig connectConfig, FullServicesContainer servicesContainer) { + super(connectConfig, servicesContainer, "", "", ""); } public static void clearCache() { diff --git a/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java b/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java index 7d96686eb..50d108610 100644 --- a/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java +++ b/wrapper/src/test/java/integration/container/aurora/TestPluginServiceImpl.java @@ -20,14 +20,14 @@ import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.PluginServiceImpl; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class TestPluginServiceImpl extends PluginServiceImpl { public TestPluginServiceImpl( - @NonNull FullServicesContainer servicesContainer, @NonNull ConnectionInfo connectionInfo) + @NonNull FullServicesContainer servicesContainer, @NonNull ConnectConfig connectConfig) throws SQLException { - super(servicesContainer, connectionInfo); + super(servicesContainer, connectConfig); } public static void clearHostAvailabilityCache() { diff --git a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java index 6a89a2f1b..8c1749b64 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java @@ -64,7 +64,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.wrapper.ConnectionWrapper; @@ -76,7 +76,7 @@ public class ConnectionPluginManagerTests { @Mock JdbcCallable mockSqlFunction; @Mock ConnectionProvider mockConnectionProvider; @Mock ConnectionWrapper mockConnectionWrapper; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock TelemetryFactory mockTelemetryFactory; @Mock TelemetryContext mockTelemetryContext; @Mock FullServicesContainer mockServicesContainer; @@ -242,7 +242,7 @@ public void testConnect() throws Exception { new ConnectionPluginManager(mockConnectionProvider, null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); - final Connection conn = target.connect(mockConnectionInfo, + final Connection conn = target.connect(mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null); @@ -275,7 +275,7 @@ public void testConnectWithSkipPlugin() throws Exception { null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); final Connection conn = target.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, pluginOne); @@ -308,7 +308,7 @@ public void testForceConnect() throws Exception { null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); final Connection conn = target.forceConnect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null); @@ -341,7 +341,7 @@ public void testConnectWithSQLExceptionBefore() { assertThrows( SQLException.class, () -> target.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null)); @@ -370,7 +370,7 @@ public void testConnectWithSQLExceptionAfter() { assertThrows( SQLException.class, () -> target.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null)); @@ -402,7 +402,7 @@ public void testConnectWithUnexpectedExceptionBefore() { assertThrows( IllegalArgumentException.class, () -> target.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null)); @@ -431,7 +431,7 @@ public void testConnectWithUnexpectedExceptionAfter() { assertThrows( IllegalArgumentException.class, () -> target.connect( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), true, null)); @@ -537,7 +537,7 @@ public void testForceConnectCachedJdbcCallForceConnect() throws Exception { null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); Object result = target.forceConnect( - mockConnectionInfo, + mockConnectConfig, testHostSpec, true, null); @@ -556,7 +556,7 @@ public void testForceConnectCachedJdbcCallForceConnect() throws Exception { calls.clear(); result = target.forceConnect( - mockConnectionInfo, + mockConnectConfig, testHostSpec, true, null); diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java index 9ff354b2c..bd1a2d4d2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java @@ -52,7 +52,7 @@ import software.amazon.jdbc.exceptions.ExceptionManager; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.StorageService; public class DialectDetectionTests { @@ -95,15 +95,15 @@ void cleanUp() throws Exception { PluginServiceImpl getPluginService(String host, String protocol) throws SQLException { return getPluginService( - new ConnectionInfo(protocol + host, protocol, mockDriverDialect, new Properties())); + new ConnectConfig(protocol + host, protocol, mockDriverDialect, new Properties())); } - PluginServiceImpl getPluginService(ConnectionInfo connectionInfo) throws SQLException { + PluginServiceImpl getPluginService(ConnectConfig connectConfig) throws SQLException { PluginServiceImpl pluginService = spy( new PluginServiceImpl( mockServicesContainer, new ExceptionManager(), - connectionInfo, + connectConfig, null, null, null)); @@ -115,10 +115,10 @@ PluginServiceImpl getPluginService(ConnectionInfo connectionInfo) throws SQLExce @ParameterizedTest @MethodSource("getInitialDialectArguments") public void testInitialDialectDetection(String protocol, String host, Object expectedDialect) throws SQLException { - final ConnectionInfo connectionInfo = - new ConnectionInfo(protocol + host, protocol, mockDriverDialect, new Properties()); + final ConnectConfig connectConfig = + new ConnectConfig(protocol + host, protocol, mockDriverDialect, new Properties()); final DialectManager dialectManager = new DialectManager(this.getPluginService(host, protocol)); - final Dialect dialect = dialectManager.getDialect(connectionInfo); + final Dialect dialect = dialectManager.getDialect(connectConfig); assertEquals(expectedDialect, dialect.getClass()); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java index cb3e549bf..d2a2854a4 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java @@ -50,12 +50,12 @@ import software.amazon.jdbc.targetdriverdialect.ConnectInfo; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Pair; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.storage.SlidingExpirationCache; class HikariPooledConnectionProviderTest { @Mock Connection mockConnection; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock HikariDataSource mockDataSource; @Mock HostSpec mockHostSpec; @Mock HikariConfig mockConfig; @@ -117,10 +117,10 @@ void init() throws SQLException { when(mxBeanWith1Connection.getActiveConnections()).thenReturn(1); when(dsWith2Connections.getHikariPoolMXBean()).thenReturn(mxBeanWith2Connections); when(mxBeanWith2Connections.getActiveConnections()).thenReturn(2); - when(mockConnectionInfo.getDriverDialect()).thenReturn(mockDriverDialect); - when(mockConnectionInfo.getDbDialect()).thenReturn(mockDbDialect); - when(mockConnectionInfo.getProps()).thenReturn(defaultProps); - when(mockConnectionInfo.getProtocol()).thenReturn(protocol); + when(mockConnectConfig.getDriverDialect()).thenReturn(mockDriverDialect); + when(mockConnectConfig.getDbDialect()).thenReturn(mockDbDialect); + when(mockConnectConfig.getProps()).thenReturn(defaultProps); + when(mockConnectConfig.getProtocol()).thenReturn(protocol); } @AfterEach @@ -144,7 +144,7 @@ void testConnectWithDefaultMapping() throws SQLException { doReturn(new ConnectInfo("url", new Properties())) .when(mockDriverDialect).prepareConnectInfo(anyString(), any(), any()); - try (Connection conn = provider.connect(mockConnectionInfo, mockHostSpec)) { + try (Connection conn = provider.connect(mockConnectConfig, mockHostSpec)) { assertEquals(mockConnection, conn); assertEquals(1, provider.getHostCount()); final Set hosts = provider.getHosts(); @@ -169,7 +169,7 @@ void testConnectWithCustomMapping() throws SQLException { Properties props = new Properties(); props.setProperty(PropertyDefinition.USER.name, user1); props.setProperty(PropertyDefinition.PASSWORD.name, password); - try (Connection conn = provider.connect(mockConnectionInfo, mockHostSpec)) { + try (Connection conn = provider.connect(mockConnectConfig, mockHostSpec)) { assertEquals(mockConnection, conn); assertEquals(1, provider.getHostCount()); final Set> keys = provider.getKeys(); @@ -184,11 +184,11 @@ public void testAcceptsUrl() { assertTrue( provider.acceptsUrl( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(readerUrl2Connection).build())); assertFalse( provider.acceptsUrl( - mockConnectionInfo, + mockConnectConfig, new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(clusterUrl).build())); } @@ -231,7 +231,7 @@ public void testConfigurePool() throws SQLException { doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) .when(mockDriverDialect).prepareConnectInfo(anyString(), any(), any()); - provider.configurePool(mockConfig, mockConnectionInfo, readerHost1Connection, defaultProps); + provider.configurePool(mockConfig, mockConnectConfig, readerHost1Connection, defaultProps); verify(mockConfig).setJdbcUrl(expectedJdbcUrl); verify(mockConfig).setUsername(user1); verify(mockConfig).setPassword(password); @@ -242,10 +242,10 @@ public void testConnectToDeletedInstance() throws SQLException { provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); doReturn(mockDataSource).when(provider) - .createHikariDataSource(eq(mockConnectionInfo), eq(readerHost1Connection), eq(defaultProps)); + .createHikariDataSource(eq(mockConnectConfig), eq(readerHost1Connection), eq(defaultProps)); when(mockDataSource.getConnection()).thenThrow(SQLException.class); assertThrows(SQLException.class, - () -> provider.connect(mockConnectionInfo, readerHost1Connection)); + () -> provider.connect(mockConnectConfig, readerHost1Connection)); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index 8ac043902..8d0bce63c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -61,7 +61,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.storage.TestStorageServiceImpl; @@ -78,7 +78,7 @@ public class PluginServiceImplTests { @Mock Connection newConnection; @Mock Connection oldConnection; @Mock HostListProvider hostListProvider; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock Statement statement; @Mock ResultSet resultSet; @@ -94,9 +94,9 @@ void setUp() throws SQLException { when(statement.executeQuery(any())).thenReturn(resultSet); when(servicesContainer.getConnectionPluginManager()).thenReturn(pluginManager); when(servicesContainer.getStorageService()).thenReturn(storageService); - when(mockConnectionInfo.getProps()).thenReturn(props); - when(mockConnectionInfo.getInitialConnectionString()).thenReturn("url"); - when(mockConnectionInfo.getProtocol()).thenReturn("jdbc:postgresql://"); + when(mockConnectConfig.getProps()).thenReturn(props); + when(mockConnectConfig.getInitialConnectionString()).thenReturn("url"); + when(mockConnectConfig.getProtocol()).thenReturn("jdbc:postgresql://"); storageService = new TestStorageServiceImpl(mockEventPublisher); PluginServiceImpl.hostAvailabilityExpiringCache.clear(); } @@ -129,7 +129,7 @@ public void testOldConnectionNoSuggestion() throws SQLException { } protected PluginServiceImpl getPluginService() throws SQLException { - return new PluginServiceImpl(servicesContainer, mockConnectionInfo); + return new PluginServiceImpl(servicesContainer, mockConnectConfig); } @Test diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java index d064dfab2..ac10855cf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsHostListProviderTest.java @@ -66,7 +66,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.storage.TestStorageServiceImpl; @@ -118,9 +118,9 @@ void tearDown() throws Exception { } private RdsHostListProvider getRdsHostListProvider(String originalUrl) throws SQLException { - ConnectionInfo connectionInfo = new ConnectionInfo(originalUrl, mockDriverDialect, new Properties()); + ConnectConfig connectConfig = new ConnectConfig(originalUrl, mockDriverDialect, new Properties()); RdsHostListProvider provider = new RdsHostListProvider( - connectionInfo, + connectConfig, mockServicesContainer, "foo", "bar", "baz"); provider.init(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java index e699d0917..251d96cf3 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/RdsMultiAzDbClusterListProviderTest.java @@ -60,7 +60,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider.FetchTopologyResult; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.storage.TestStorageServiceImpl; @@ -111,9 +111,9 @@ void tearDown() throws Exception { } private RdsMultiAzDbClusterListProvider getRdsMazDbClusterHostListProvider(String originalUrl) throws SQLException { - ConnectionInfo connectionInfo = new ConnectionInfo(originalUrl, mockDriverDialect, new Properties()); + ConnectConfig connectConfig = new ConnectConfig(originalUrl, mockDriverDialect, new Properties()); RdsMultiAzDbClusterListProvider provider = new RdsMultiAzDbClusterListProvider( - connectionInfo, + connectConfig, mockServicesContainer, "foo", "bar", diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java index d0093d090..5336ec64d 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginOne.java @@ -34,7 +34,7 @@ import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class TestPluginOne implements ConnectionPlugin { @@ -85,7 +85,7 @@ public T execute( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -98,7 +98,7 @@ public Connection connect( @Override public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { @@ -130,7 +130,7 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin @Override public void initHostProvider( - ConnectionInfo connectionInfo, + ConnectConfig connectConfig, HostListProviderService hostListProviderService, JdbcCallable initHostProviderFunc) { // do nothing diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java index 7b1f85689..8f545c528 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java @@ -24,7 +24,7 @@ import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.JdbcMethod; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class TestPluginThree extends TestPluginOne { @@ -45,7 +45,7 @@ public TestPluginThree(ArrayList calls, Connection connection) { @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { @@ -65,7 +65,7 @@ public Connection connect( @Override public Connection forceConnect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable forceConnectFunc) throws SQLException { diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java index 9412d3b23..f682b7b47 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThrowException.java @@ -23,7 +23,7 @@ import java.util.HashSet; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class TestPluginThrowException extends TestPluginOne { @@ -76,7 +76,7 @@ public T execute( @Override public Connection connect( - final ConnectionInfo connectionInfo, + final ConnectConfig connectConfig, final HostSpec hostSpec, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index 226aacffa..776678b8b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -54,7 +54,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class AuroraConnectionTrackerPluginTest { @@ -112,10 +112,10 @@ public void testTrackNewInstanceConnections( mockRdsUtils, mockTracker); - final ConnectionInfo connectionInfo = - new ConnectionInfo(protocol + hostSpec.getHost(), mockDriverDialect, EMPTY_PROPERTIES); + final ConnectConfig connectConfig = + new ConnectConfig(protocol + hostSpec.getHost(), mockDriverDialect, EMPTY_PROPERTIES); final Connection actualConnection = plugin.connect( - connectionInfo, + connectConfig, hostSpec, isInitialConnection, mockConnectionFunction); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java index 0d6f8853f..88de04621 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPluginTest.java @@ -69,7 +69,7 @@ import software.amazon.jdbc.util.FullServicesContainer; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Pair; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -127,7 +127,7 @@ public void init() throws SQLException { REGION_PROPERTY.set(TEST_PROPS, TEST_REGION); SECRET_ID_PROPERTY.set(TEST_PROPS, TEST_SECRET_ID); - when(mockDialectManager.getDialect(any(ConnectionInfo.class))) + when(mockDialectManager.getDialect(any(ConnectConfig.class))) .thenReturn(mockTopologyAwareDialect); when(mockServicesContainer.getConnectionPluginManager()).thenReturn(mockConnectionPluginManager); @@ -145,7 +145,7 @@ public void init() throws SQLException { (host, r) -> mockSecretsManagerClient, (id) -> mockGetValueRequest); - when(mockDialectManager.getDialect(any(ConnectionInfo.class))) + when(mockDialectManager.getDialect(any(ConnectConfig.class))) .thenReturn(mockTopologyAwareDialect); when(mockService.getHostSpecBuilder()).thenReturn(new HostSpecBuilder(new SimpleHostAvailabilityStrategy())); @@ -166,7 +166,7 @@ public void testConnectWithCachedSecrets() throws SQLException { // Add initial cached secret to be used for a connection. AwsSecretsManagerCacheHolder.secretsCache.put(SECRET_CACHE_KEY, TEST_SECRET); - this.plugin.connect(getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc); + this.plugin.connect(getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc); assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); verify(this.mockSecretsManagerClient, never()).getSecretValue(this.mockGetValueRequest); @@ -175,8 +175,8 @@ public void testConnectWithCachedSecrets() throws SQLException { assertEquals(TEST_PASSWORD, TEST_PROPS.get(PropertyDefinition.PASSWORD.name)); } - protected ConnectionInfo getConnectionInfo(String protocol) { - return new ConnectionInfo(protocol, mockDriverDialect, TEST_PROPS); + protected ConnectConfig getConnectConfig(String protocol) { + return new ConnectConfig(protocol, mockDriverDialect, TEST_PROPS); } /** @@ -188,7 +188,7 @@ public void testConnectWithNewSecrets() throws SQLException { when(this.mockSecretsManagerClient.getSecretValue(this.mockGetValueRequest)) .thenReturn(VALID_GET_SECRET_VALUE_RESPONSE); - this.plugin.connect(getConnectionInfo(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc); + this.plugin.connect(getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc); assertEquals(1, AwsSecretsManagerCacheHolder.secretsCache.size()); verify(this.mockSecretsManagerClient).getSecretValue(this.mockGetValueRequest); @@ -220,7 +220,7 @@ public void testFailedInitialConnectionWithUnhandledError() throws SQLException final SQLException connectionFailedException = assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_PG_PROTOCOL), + getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -262,7 +262,7 @@ public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_PG_PROTOCOL), + getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -275,7 +275,7 @@ public void testConnectWithNewSecretsAfterTryingWithCachedSecrets( } private @NotNull PluginServiceImpl getPluginService(String protocol) throws SQLException { - return new PluginServiceImpl(mockServicesContainer, getConnectionInfo(protocol)); + return new PluginServiceImpl(mockServicesContainer, getConnectConfig(protocol)); } /** @@ -291,7 +291,7 @@ public void testFailedToReadSecrets() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_PG_PROTOCOL), + getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -317,7 +317,7 @@ public void testFailedToGetSecrets() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_PG_PROTOCOL), + getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -353,7 +353,7 @@ public void testFailedInitialConnectionWithWrappedGenericError(final String acce assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_PG_PROTOCOL), + getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -384,7 +384,7 @@ public void testConnectWithWrappedMySQLException() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_MYSQL_PROTOCOL), + getConnectConfig(TEST_MYSQL_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -415,7 +415,7 @@ public void testConnectWithWrappedPostgreSQLException() throws SQLException { assertThrows( SQLException.class, () -> this.plugin.connect( - getConnectionInfo(TEST_PG_PROTOCOL), + getConnectConfig(TEST_PG_PROTOCOL), TEST_HOSTSPEC, true, this.connectFunc)); @@ -435,7 +435,7 @@ public void testConnectViaARN(final String arn, final Region expectedRegionParse SECRET_ID_PROPERTY.set(props, arn); this.plugin = spy(new AwsSecretsManagerConnectionPlugin( - new PluginServiceImpl(mockServicesContainer, getConnectionInfo(TEST_PG_PROTOCOL)), + new PluginServiceImpl(mockServicesContainer, getConnectConfig(TEST_PG_PROTOCOL)), props, (host, r) -> mockSecretsManagerClient, (id) -> mockGetValueRequest)); @@ -455,7 +455,7 @@ public void testConnectionWithRegionParameterAndARN(final String arn, final Regi REGION_PROPERTY.set(props, expectedRegion.toString()); this.plugin = spy(new AwsSecretsManagerConnectionPlugin( - new PluginServiceImpl(mockServicesContainer, getConnectionInfo(TEST_PG_PROTOCOL)), + new PluginServiceImpl(mockServicesContainer, getConnectConfig(TEST_PG_PROTOCOL)), props, (host, r) -> mockSecretsManagerClient, (id) -> mockGetValueRequest)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java index e36e9d4a4..5c6b0ca73 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/DefaultConnectionPluginTest.java @@ -48,7 +48,7 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginManagerService; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.GaugeCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -62,7 +62,7 @@ class DefaultConnectionPluginTest { @Mock PluginService pluginService; @Mock ConnectionProvider connectionProvider; @Mock PluginManagerService pluginManagerService; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock JdbcCallable mockSqlFunction; @Mock JdbcCallable mockConnectFunction; @Mock Connection conn; @@ -122,7 +122,7 @@ void testExecute_closeOldConnection() throws SQLException { @Test void testConnect() throws SQLException { - plugin.connect(mockConnectionInfo, mockHostSpec, true, mockConnectFunction); + plugin.connect(mockConnectConfig, mockHostSpec, true, mockConnectFunction); verify(connectionProvider, atLeastOnce()).connect(any(), any()); verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), any(), any()); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java index 902005830..b5866c81d 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java @@ -48,7 +48,7 @@ import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.monitoring.MonitorService; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -65,7 +65,7 @@ public class CustomEndpointPluginTest { private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); @Mock private FullServicesContainer mockServicesContainer; - @Mock private ConnectionInfo mockConnectionInfo; + @Mock private ConnectConfig mockConnectConfig; @Mock private PluginService mockPluginService; @Mock private MonitorService mockMonitorService; @Mock private BiFunction mockRdsClientFunc; @@ -108,7 +108,7 @@ private CustomEndpointPlugin getSpyPlugin() throws SQLException { public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { CustomEndpointPlugin spyPlugin = getSpyPlugin(); - spyPlugin.connect(mockConnectionInfo, writerClusterHost, true, mockConnectFunc); + spyPlugin.connect(mockConnectConfig, writerClusterHost, true, mockConnectFunc); verify(mockConnectFunc, times(1)).call(); verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); @@ -118,7 +118,7 @@ public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLExc public void testConnect_monitorCreated() throws SQLException { CustomEndpointPlugin spyPlugin = getSpyPlugin(); - spyPlugin.connect(mockConnectionInfo, host, true, mockConnectFunc); + spyPlugin.connect(mockConnectConfig, host, true, mockConnectFunc); verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); verify(mockConnectFunc, times(1)).call(); @@ -130,7 +130,7 @@ public void testConnect_timeoutWaitingForInfo() throws SQLException { CustomEndpointPlugin spyPlugin = getSpyPlugin(); when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); - assertThrows(SQLException.class, () -> spyPlugin.connect(mockConnectionInfo, host, true, mockConnectFunc)); + assertThrows(SQLException.class, () -> spyPlugin.connect(mockConnectConfig, host, true, mockConnectFunc)); verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); verify(mockConnectFunc, never()).call(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java index b739e02a8..b203fa0ea 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java @@ -63,7 +63,7 @@ import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; class HostMonitoringConnectionPluginTest { @@ -77,7 +77,7 @@ class HostMonitoringConnectionPluginTest { @Mock PluginService pluginService; @Mock Dialect mockDialect; @Mock Connection connection; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock Statement statement; @Mock ResultSet resultSet; Properties properties = new Properties(); @@ -263,7 +263,7 @@ void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException doThrow(new SQLException()).when(connection).createStatement(); // Ensure SQLException raised in `generateHostAliases` are ignored. - final Connection conn = plugin.connect(mockConnectionInfo, hostSpec, true, () -> connection); + final Connection conn = plugin.connect(mockConnectConfig, hostSpec, true, () -> connection); assertNotNull(conn); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java index f17fc3e2a..f137084e1 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -50,7 +50,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.plugin.iam.IamTokenUtility; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -77,7 +77,7 @@ class FederatedAuthPluginTest { @Mock private IamTokenUtility mockIamTokenUtils; @Mock private CompletableFuture completableFuture; @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; - @Mock private ConnectionInfo mockConnectionInfo; + @Mock private ConnectConfig mockConnectConfig; private Properties props; private AutoCloseable closeable; @@ -104,7 +104,7 @@ public void init() throws ExecutionException, InterruptedException, SQLException when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) .thenReturn(mockAwsCredentialsProvider); when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); - when(mockConnectionInfo.getProps()).thenReturn(props); + when(mockConnectConfig.getProps()).thenReturn(props); when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); } @@ -121,7 +121,7 @@ void testCachedToken() throws SQLException { String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + plugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -138,7 +138,7 @@ void testExpiredCachedToken() throws SQLException { someExpiredToken, Instant.now().minusMillis(300000)); FederatedAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); - spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + spyPlugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, Region.US_EAST_2, HOST_SPEC.getHost(), @@ -153,7 +153,7 @@ void testNoCachedToken() throws SQLException { FederatedAuthPlugin spyPlugin = Mockito.spy( new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + spyPlugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken( mockAwsCredentialsProvider, Region.US_EAST_2, @@ -180,7 +180,7 @@ void testSpecifiedIamHostPortRegion() throws SQLException { FederatedAuthPlugin plugin = new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + plugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -199,7 +199,7 @@ void testIdpCredentialsFallback() throws SQLException { String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; FederatedAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + plugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -213,7 +213,7 @@ public void testUsingIamHost() throws SQLException { FederatedAuthPlugin spyPlugin = Mockito.spy( new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + spyPlugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java index 149d2e5d1..039ce76bd 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPluginTest.java @@ -47,7 +47,7 @@ import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin; import software.amazon.jdbc.plugin.iam.IamTokenUtility; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -73,7 +73,7 @@ class OktaAuthPluginTest { @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; @Mock private RdsUtils mockRdsUtils; @Mock private IamTokenUtility mockIamTokenUtils; - @Mock private ConnectionInfo mockConnectionInfo; + @Mock private ConnectConfig mockConnectConfig; private Properties props; private AutoCloseable closeable; @@ -96,7 +96,7 @@ void setUp() throws SQLException { when(mockPluginService.getDialect()).thenReturn(mockDialect); when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockConnectionInfo.getProps()).thenReturn(props); + when(mockConnectConfig.getProps()).thenReturn(props); when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) @@ -116,7 +116,7 @@ void testCachedToken() throws SQLException { String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + plugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -133,7 +133,7 @@ void testExpiredCachedToken() throws SQLException { someExpiredToken, Instant.now().minusMillis(300000)); OktaAuthCacheHolder.tokenCache.put(key, expiredTokenInfo); - spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + spyPlugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken(mockAwsCredentialsProvider, Region.US_EAST_2, HOST_SPEC.getHost(), @@ -148,7 +148,7 @@ void testNoCachedToken() throws SQLException { final OktaAuthPlugin spyPlugin = new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + spyPlugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); verify(mockIamTokenUtils).generateAuthenticationToken( mockAwsCredentialsProvider, Region.US_EAST_2, @@ -175,7 +175,7 @@ void testSpecifiedIamHostPortRegion() throws SQLException { OktaAuthPlugin plugin = new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils); - plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + plugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -194,7 +194,7 @@ void testIdpCredentialsFallback() throws SQLException { final String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; OktaAuthCacheHolder.tokenCache.put(key, TEST_TOKEN_INFO); - plugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + plugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -208,7 +208,7 @@ public void testUsingIamHost() throws SQLException { OktaAuthPlugin spyPlugin = Mockito.spy( new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils)); - spyPlugin.connect(mockConnectionInfo, HOST_SPEC, true, mockLambda); + spyPlugin.connect(mockConnectConfig, HOST_SPEC, true, mockLambda); assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java index fe48edd8c..0d7361a87 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginTest.java @@ -55,7 +55,7 @@ import software.amazon.jdbc.plugin.TokenInfo; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -262,8 +262,8 @@ public void testTokenSetInProps(final String protocol, final HostSpec hostSpec) IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService, mockIamTokenUtils); doThrow(new SQLException()).when(mockLambda).call(); - ConnectionInfo connectionInfo = new ConnectionInfo(protocol + hostSpec.getHost(), mockDriverDialect, props); - assertThrows(SQLException.class, () -> targetPlugin.connect(connectionInfo, hostSpec, true, mockLambda)); + ConnectConfig connectConfig = new ConnectConfig(protocol + hostSpec.getHost(), mockDriverDialect, props); + assertThrows(SQLException.class, () -> targetPlugin.connect(connectConfig, hostSpec, true, mockLambda)); verify(mockLambda, times(1)).call(); assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); @@ -282,9 +282,9 @@ private void testGenerateToken( doThrow(new SQLException()).when(mockLambda).call(); - ConnectionInfo connectionInfo = new ConnectionInfo(protocol + hostSpec.getHost(), mockDriverDialect, props); + ConnectConfig connectConfig = new ConnectConfig(protocol + hostSpec.getHost(), mockDriverDialect, props); assertThrows(SQLException.class, - () -> spyPlugin.connect(connectionInfo, hostSpec, true, mockLambda)); + () -> spyPlugin.connect(connectConfig, hostSpec, true, mockLambda)); verify(mockIamTokenUtils).generateAuthenticationToken( any(DefaultCredentialsProvider.class), diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java index 281171b76..ef60fad94 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPluginTest.java @@ -43,7 +43,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.dialect.PgDialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class LimitlessConnectionPluginTest { @@ -53,7 +53,7 @@ public class LimitlessConnectionPluginTest { private static final Dialect supportedDialect = new AuroraPgDialect(); @Mock JdbcCallable mockConnectFuncLambda; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock private Connection mockConnection; @Mock private PluginService mockPluginService; @Mock private HostListProvider mockHostListProvider; @@ -90,7 +90,7 @@ void testConnect() throws SQLException { }).when(mockLimitlessRouterService).establishConnection(any()); final Connection expectedConnection = mockConnection; - final Connection actualConnection = plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, + final Connection actualConnection = plugin.connect(mockConnectConfig, INPUT_HOST_SPEC, true, mockConnectFuncLambda); assertEquals(expectedConnection, actualConnection); @@ -111,7 +111,7 @@ void testConnectGivenNullConnection() throws SQLException { assertThrows( SQLException.class, - () -> plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, mockConnectFuncLambda)); + () -> plugin.connect(mockConnectConfig, INPUT_HOST_SPEC, true, mockConnectFuncLambda)); verify(mockPluginService, times(1)).getDialect(); verify(mockConnectFuncLambda, times(0)).call(); @@ -127,7 +127,7 @@ void testConnectGivenUnsupportedDialect() throws SQLException { assertThrows( UnsupportedOperationException.class, - () -> plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, mockConnectFuncLambda)); + () -> plugin.connect(mockConnectConfig, INPUT_HOST_SPEC, true, mockConnectFuncLambda)); verify(mockPluginService, times(2)).getDialect(); verify(mockConnectFuncLambda, times(1)).call(); @@ -142,7 +142,7 @@ void testConnectGivenSupportedDialectAfterRefresh() throws SQLException { when(mockPluginService.getDialect()).thenReturn(unsupportedDialect, supportedDialect); final Connection expectedConnection = mockConnection; - final Connection actualConnection = plugin.connect(mockConnectionInfo, INPUT_HOST_SPEC, true, + final Connection actualConnection = plugin.connect(mockConnectConfig, INPUT_HOST_SPEC, true, mockConnectFuncLambda); assertEquals(expectedConnection, actualConnection); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index 0d9c678d7..22559d69e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -55,7 +55,7 @@ import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; import software.amazon.jdbc.util.SqlState; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; public class ReadWriteSplittingPluginTest { private static final int TEST_PORT = 5432; @@ -101,7 +101,7 @@ public class ReadWriteSplittingPluginTest { @Mock private Connection mockReaderConn3; @Mock private Statement mockStatement; @Mock private ResultSet mockResultSet; - @Mock private ConnectionInfo mockConnectionInfo; + @Mock private ConnectConfig mockConnectConfig; @Mock private EnumSet mockChanges; @BeforeEach @@ -402,7 +402,7 @@ public void testConnectNonInitialConnection() throws SQLException { null); final Connection connection = - plugin.connect(mockConnectionInfo, writerHostSpec, false, this.mockConnectFunc); + plugin.connect(mockConnectConfig, writerHostSpec, false, this.mockConnectFunc); assertEquals(mockWriterConn, connection); verify(mockConnectFunc).call(); @@ -421,7 +421,7 @@ public void testConnectRdsInstanceUrl() throws SQLException { null, null); final Connection connection = plugin.connect( - mockConnectionInfo, + mockConnectConfig, instanceUrlHostSpec, true, this.mockConnectFunc); @@ -443,7 +443,7 @@ public void testConnectReaderIpUrl() throws SQLException { null, null); final Connection connection = - plugin.connect(mockConnectionInfo, ipUrlHostSpec, true, this.mockConnectFunc); + plugin.connect(mockConnectConfig, ipUrlHostSpec, true, this.mockConnectFunc); assertEquals(mockReaderConn1, connection); verify(mockConnectFunc).call(); @@ -459,7 +459,7 @@ public void testConnectClusterUrl() throws SQLException { null, null); final Connection connection = - plugin.connect(mockConnectionInfo, clusterUrlHostSpec, true, this.mockConnectFunc); + plugin.connect(mockConnectConfig, clusterUrlHostSpec, true, this.mockConnectFunc); assertEquals(mockWriterConn, connection); verify(mockConnectFunc).call(); @@ -480,7 +480,7 @@ public void testConnect_errorUpdatingHostSpec() throws SQLException { assertThrows( SQLException.class, () -> plugin.connect( - mockConnectionInfo, + mockConnectConfig, ipUrlHostSpec, true, this.mockConnectFunc)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java index 33ef71bcb..605f00576 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/monitoring/MonitorServiceImplTest.java @@ -39,7 +39,7 @@ import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; import software.amazon.jdbc.util.FullServicesContainer; -import software.amazon.jdbc.util.connection.ConnectionInfo; +import software.amazon.jdbc.util.connection.ConnectConfig; import software.amazon.jdbc.util.events.EventPublisher; import software.amazon.jdbc.util.storage.StorageService; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -47,7 +47,7 @@ class MonitorServiceImplTest { @Mock FullServicesContainer mockServicesContainer; @Mock StorageService mockStorageService; - @Mock ConnectionInfo mockConnectionInfo; + @Mock ConnectConfig mockConnectConfig; @Mock ConnectionProvider mockConnectionProvider; @Mock TelemetryFactory mockTelemetryFactory; @Mock EventPublisher mockPublisher; @@ -63,7 +63,7 @@ void setUp() throws SQLException { eq(mockStorageService), eq(mockConnectionProvider), eq(mockTelemetryFactory), - eq(mockConnectionInfo)); + eq(mockConnectConfig)); } @AfterEach @@ -88,7 +88,7 @@ public void testMonitorError_monitorReCreated() throws SQLException, Interrupted mockStorageService, mockTelemetryFactory, mockConnectionProvider, - mockConnectionInfo, + mockConnectConfig, (mockServicesContainer) -> new NoOpMonitor(30) ); @@ -128,7 +128,7 @@ public void testMonitorStuck_monitorReCreated() throws SQLException, Interrupted mockStorageService, mockTelemetryFactory, mockConnectionProvider, - mockConnectionInfo, + mockConnectConfig, (mockServicesContainer) -> new NoOpMonitor(30) ); @@ -170,7 +170,7 @@ public void testMonitorExpired() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - mockConnectionInfo, + mockConnectConfig, (mockServicesContainer) -> new NoOpMonitor(30) ); @@ -199,7 +199,7 @@ public void testMonitorMismatch() { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - mockConnectionInfo, + mockConnectConfig, // indicated monitor class is CustomEndpointMonitorImpl, but actual monitor is NoOpMonitor. The monitor // service should detect this and throw an exception. (mockServicesContainer) -> new NoOpMonitor(30) @@ -225,7 +225,7 @@ public void testRemove() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - mockConnectionInfo, + mockConnectConfig, (mockServicesContainer) -> new NoOpMonitor(30) ); assertNotNull(monitor); @@ -256,7 +256,7 @@ public void testStopAndRemove() throws SQLException, InterruptedException { mockStorageService, mockTelemetryFactory, mockConnectionProvider, - mockConnectionInfo, + mockConnectConfig, (mockServicesContainer) -> new NoOpMonitor(30) ); assertNotNull(monitor); From 6611719704f92b940c466c3e24aea6f728f2c02f Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 29 Sep 2025 09:00:27 -0700 Subject: [PATCH 54/54] ConnectInfo -> ConnectParams --- .../jdbc/C3P0PooledConnectionProvider.java | 18 ++++++++--------- .../amazon/jdbc/DriverConnectionProvider.java | 20 +++++++++---------- .../jdbc/HikariPooledConnectionProvider.java | 16 +++++++-------- .../{ConnectInfo.java => ConnectParams.java} | 8 ++++---- .../GenericTargetDriverDialect.java | 4 ++-- .../MariadbTargetDriverDialect.java | 8 ++++---- .../MysqlConnectorJTargetDriverDialect.java | 8 ++++---- .../PgTargetDriverDialect.java | 8 ++++---- .../TargetDriverDialect.java | 2 +- .../HikariPooledConnectionProviderTest.java | 10 +++++----- 10 files changed, 51 insertions(+), 51 deletions(-) rename wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/{ConnectInfo.java => ConnectParams.java} (80%) diff --git a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java index 9c055f407..6ad843baa 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/C3P0PooledConnectionProvider.java @@ -31,7 +31,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.targetdriverdialect.ConnectInfo; +import software.amazon.jdbc.targetdriverdialect.ConnectParams; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.connection.ConnectConfig; @@ -99,31 +99,31 @@ protected ComboPooledDataSource createDataSource( @NonNull ConnectConfig connectConfig, @NonNull HostSpec hostSpec, @NonNull Properties props) { - ConnectInfo connectInfo; + ConnectParams connectParams; try { - connectInfo = connectConfig.getDriverDialect() - .prepareConnectInfo(connectConfig.getProtocol(), hostSpec, props); + connectParams = + connectConfig.getDriverDialect().prepareConnectParams(connectConfig.getProtocol(), hostSpec, props); } catch (SQLException ex) { throw new RuntimeException(ex); } - final StringBuilder urlBuilder = new StringBuilder(connectInfo.url); + final StringBuilder urlBuilder = new StringBuilder(connectParams.connectionString); final StringJoiner propsJoiner = new StringJoiner("&"); - connectInfo.props.forEach((k, v) -> { + connectParams.props.forEach((k, v) -> { if (!PropertyDefinition.PASSWORD.name.equals(k) && !PropertyDefinition.USER.name.equals(k)) { propsJoiner.add(k + "=" + v); } }); - urlBuilder.append(connectInfo.url.contains("?") ? "&" : "?").append(propsJoiner); + urlBuilder.append(connectParams.connectionString.contains("?") ? "&" : "?").append(propsJoiner); ComboPooledDataSource ds = new ComboPooledDataSource(); ds.setJdbcUrl(urlBuilder.toString()); - final String user = connectInfo.props.getProperty(PropertyDefinition.USER.name); - final String password = connectInfo.props.getProperty(PropertyDefinition.PASSWORD.name); + final String user = connectParams.props.getProperty(PropertyDefinition.USER.name); + final String password = connectParams.props.getProperty(PropertyDefinition.PASSWORD.name); if (user != null) { ds.setUser(user); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java index 306b8b5f2..36f806501 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/DriverConnectionProvider.java @@ -29,7 +29,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.exceptions.SQLLoginException; -import software.amazon.jdbc.targetdriverdialect.ConnectInfo; +import software.amazon.jdbc.targetdriverdialect.ConnectParams; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.PropertyUtils; import software.amazon.jdbc.util.RdsUtils; @@ -98,18 +98,18 @@ public HostSpec getHostSpecByStrategy( public Connection connect(final @NonNull ConnectConfig connectConfig, final @NonNull HostSpec hostSpec) throws SQLException { final Properties propsCopy = PropertyUtils.copyProperties(connectConfig.getProps()); - final ConnectInfo connectInfo = - connectConfig.getDriverDialect().prepareConnectInfo(connectConfig.getProtocol(), hostSpec, propsCopy); + final ConnectParams connectParams = + connectConfig.getDriverDialect().prepareConnectParams(connectConfig.getProtocol(), hostSpec, propsCopy); connectConfig.getDbDialect().prepareConnectProperties(propsCopy, connectConfig.getProtocol(), hostSpec); - LOGGER.finest(() -> "Connecting to " + connectInfo.url + LOGGER.finest(() -> "Connecting to " + connectParams.connectionString + PropertyUtils.logProperties( - PropertyUtils.maskProperties(connectInfo.props), + PropertyUtils.maskProperties(connectParams.props), "\nwith properties: \n")); Connection conn; try { - conn = this.driver.connect(connectInfo.url, connectInfo.props); + conn = this.driver.connect(connectParams.connectionString, connectParams.props); } catch (Throwable throwable) { @@ -158,15 +158,15 @@ public Connection connect(final @NonNull ConnectConfig connectConfig, final @Non .host(fixedHost) .build(); - final ConnectInfo fixedConnectInfo = connectConfig.getDriverDialect().prepareConnectInfo( + final ConnectParams fixedConnectParams = connectConfig.getDriverDialect().prepareConnectParams( connectConfig.getProtocol(), connectionHostSpec, propsCopy); - LOGGER.finest(() -> "Connecting to " + fixedConnectInfo.url + LOGGER.finest(() -> "Connecting to " + fixedConnectParams.connectionString + " after correcting the hostname from " + originalHost + PropertyUtils.logProperties( - PropertyUtils.maskProperties(fixedConnectInfo.props), "\nwith properties: \n")); + PropertyUtils.maskProperties(fixedConnectParams.props), "\nwith properties: \n")); - conn = this.driver.connect(fixedConnectInfo.url, fixedConnectInfo.props); + conn = this.driver.connect(fixedConnectParams.connectionString, fixedConnectParams.props); } if (conn == null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java index 032924c65..21a833b46 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HikariPooledConnectionProvider.java @@ -35,7 +35,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.cleanup.CanReleaseResources; -import software.amazon.jdbc.targetdriverdialect.ConnectInfo; +import software.amazon.jdbc.targetdriverdialect.ConnectParams; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.Pair; import software.amazon.jdbc.util.PropertyUtils; @@ -317,24 +317,24 @@ protected void configurePool( final Properties connectionProps) { final Properties copy = PropertyUtils.copyProperties(connectionProps); - ConnectInfo connectInfo; + ConnectParams connectParams; try { - connectInfo = connectConfig.getDriverDialect().prepareConnectInfo( + connectParams = connectConfig.getDriverDialect().prepareConnectParams( connectConfig.getProtocol(), hostSpec, copy); } catch (SQLException ex) { throw new RuntimeException(ex); } - StringBuilder urlBuilder = new StringBuilder(connectInfo.url); + StringBuilder urlBuilder = new StringBuilder(connectParams.connectionString); final StringJoiner propsJoiner = new StringJoiner("&"); - connectInfo.props.forEach((k, v) -> { + connectParams.props.forEach((k, v) -> { if (!PropertyDefinition.PASSWORD.name.equals(k) && !PropertyDefinition.USER.name.equals(k)) { propsJoiner.add(k + "=" + v); } }); - if (connectInfo.url.contains("?")) { + if (connectParams.connectionString.contains("?")) { urlBuilder.append("&").append(propsJoiner); } else { urlBuilder.append("?").append(propsJoiner); @@ -342,8 +342,8 @@ protected void configurePool( config.setJdbcUrl(urlBuilder.toString()); - final String user = connectInfo.props.getProperty(PropertyDefinition.USER.name); - final String password = connectInfo.props.getProperty(PropertyDefinition.PASSWORD.name); + final String user = connectParams.props.getProperty(PropertyDefinition.USER.name); + final String password = connectParams.props.getProperty(PropertyDefinition.PASSWORD.name); if (user != null) { config.setUsername(user); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/ConnectInfo.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/ConnectParams.java similarity index 80% rename from wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/ConnectInfo.java rename to wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/ConnectParams.java index 06e93947d..b129e0940 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/ConnectInfo.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/ConnectParams.java @@ -18,12 +18,12 @@ import java.util.Properties; -public class ConnectInfo { - public String url; +public class ConnectParams { + public String connectionString; public Properties props; - public ConnectInfo(final String url, final Properties props) { - this.url = url; + public ConnectParams(final String connectionString, final Properties props) { + this.connectionString = connectionString; this.props = props; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/GenericTargetDriverDialect.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/GenericTargetDriverDialect.java index e740df215..4ecceeae4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/GenericTargetDriverDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/GenericTargetDriverDialect.java @@ -186,7 +186,7 @@ public boolean isDialect(String dataSourceClass) { } @Override - public ConnectInfo prepareConnectInfo(final @NonNull String protocol, + public ConnectParams prepareConnectParams(final @NonNull String protocol, final @NonNull HostSpec hostSpec, final @NonNull Properties props) throws SQLException { @@ -200,7 +200,7 @@ public ConnectInfo prepareConnectInfo(final @NonNull String protocol, // and use them to make a connection PropertyDefinition.removeAllExceptCredentials(props); - return new ConnectInfo(urlBuilder, props); + return new ConnectParams(urlBuilder, props); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbTargetDriverDialect.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbTargetDriverDialect.java index 506a249d2..af986e016 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbTargetDriverDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MariadbTargetDriverDialect.java @@ -109,9 +109,9 @@ public boolean isDialect(String dataSourceClass) { } @Override - public ConnectInfo prepareConnectInfo(final @NonNull String protocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) throws SQLException { + public ConnectParams prepareConnectParams(final @NonNull String protocol, + final @NonNull HostSpec hostSpec, + final @NonNull Properties props) throws SQLException { final String databaseName = PropertyDefinition.DATABASE.getString(props) != null @@ -133,7 +133,7 @@ public ConnectInfo prepareConnectInfo(final @NonNull String protocol, String urlBuilder = protocol + hostSpec.getUrl() + databaseName + (permitMysqlSchemeFlag ? "?" + PERMIT_MYSQL_SCHEME : ""); - return new ConnectInfo(urlBuilder, props); + return new ConnectParams(urlBuilder, props); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MysqlConnectorJTargetDriverDialect.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MysqlConnectorJTargetDriverDialect.java index 3548fcddf..09b1a7cf1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MysqlConnectorJTargetDriverDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/MysqlConnectorJTargetDriverDialect.java @@ -67,9 +67,9 @@ public boolean isDialect(String dataSourceClass) { } @Override - public ConnectInfo prepareConnectInfo(final @NonNull String protocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) throws SQLException { + public ConnectParams prepareConnectParams(final @NonNull String protocol, + final @NonNull HostSpec hostSpec, + final @NonNull Properties props) throws SQLException { final String databaseName = PropertyDefinition.DATABASE.getString(props) != null @@ -86,7 +86,7 @@ public ConnectInfo prepareConnectInfo(final @NonNull String protocol, PropertyDefinition.SOCKET_TIMEOUT.name, PropertyDefinition.CONNECT_TIMEOUT.name); - return new ConnectInfo(urlBuilder, props); + return new ConnectParams(urlBuilder, props); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/PgTargetDriverDialect.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/PgTargetDriverDialect.java index b4c56f334..b419190f1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/PgTargetDriverDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/PgTargetDriverDialect.java @@ -96,9 +96,9 @@ public boolean isDialect(String dataSourceClass) { } @Override - public ConnectInfo prepareConnectInfo(final @NonNull String protocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props) throws SQLException { + public ConnectParams prepareConnectParams(final @NonNull String protocol, + final @NonNull HostSpec hostSpec, + final @NonNull Properties props) throws SQLException { final String databaseName = PropertyDefinition.DATABASE.getString(props) != null @@ -133,7 +133,7 @@ public ConnectInfo prepareConnectInfo(final @NonNull String protocol, String urlBuilder = protocol + hostSpec.getUrl() + databaseName; - return new ConnectInfo(urlBuilder, props); + return new ConnectParams(urlBuilder, props); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/TargetDriverDialect.java b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/TargetDriverDialect.java index 1e3ebdcec..a9bad20b3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/TargetDriverDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/targetdriverdialect/TargetDriverDialect.java @@ -31,7 +31,7 @@ public interface TargetDriverDialect { boolean isDialect(final String dataSourceClass); - ConnectInfo prepareConnectInfo(final @NonNull String protocol, + ConnectParams prepareConnectParams(final @NonNull String protocol, final @NonNull HostSpec hostSpec, final @NonNull Properties props) throws SQLException; diff --git a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java index d2a2854a4..1a7a739af 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/HikariPooledConnectionProviderTest.java @@ -47,7 +47,7 @@ import org.mockito.MockitoAnnotations; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; -import software.amazon.jdbc.targetdriverdialect.ConnectInfo; +import software.amazon.jdbc.targetdriverdialect.ConnectParams; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.Pair; import software.amazon.jdbc.util.connection.ConnectConfig; @@ -141,8 +141,8 @@ void testConnectWithDefaultMapping() throws SQLException { provider = spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig)); doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any()); - doReturn(new ConnectInfo("url", new Properties())) - .when(mockDriverDialect).prepareConnectInfo(anyString(), any(), any()); + doReturn(new ConnectParams("url", new Properties())) + .when(mockDriverDialect).prepareConnectParams(anyString(), any(), any()); try (Connection conn = provider.connect(mockConnectConfig, mockHostSpec)) { assertEquals(mockConnection, conn); @@ -228,8 +228,8 @@ public void testConfigurePool() throws SQLException { provider = new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig); final String expectedJdbcUrl = protocol + readerHost1Connection.getUrl() + db + "?database=" + db; - doReturn(new ConnectInfo(protocol + readerHost1Connection.getUrl() + db, defaultProps)) - .when(mockDriverDialect).prepareConnectInfo(anyString(), any(), any()); + doReturn(new ConnectParams(protocol + readerHost1Connection.getUrl() + db, defaultProps)) + .when(mockDriverDialect).prepareConnectParams(anyString(), any(), any()); provider.configurePool(mockConfig, mockConnectConfig, readerHost1Connection, defaultProps); verify(mockConfig).setJdbcUrl(expectedJdbcUrl);