Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import java.time.Duration;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.utils.AttributeMap;

@Log4j2
public class MLHttpClientFactory {
Expand All @@ -20,7 +22,8 @@ public static SdkAsyncHttpClient getAsyncHttpClient(
Duration connectionTimeout,
Duration readTimeout,
int maxConnections,
boolean connectorPrivateIpEnabled
boolean connectorPrivateIpEnabled,
boolean skipSslVerification
) {
return doPrivileged(() -> {
log
Expand All @@ -35,7 +38,9 @@ public static SdkAsyncHttpClient getAsyncHttpClient(
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.maxConcurrency(maxConnections)
.build();
.buildWithDefaults(
AttributeMap.builder().put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, skipSslVerification).build()
);
return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ public class MLHttpClientFactoryTests {

@Test
public void test_getSdkAsyncHttpClient_success() {
SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false);
SdkAsyncHttpClient client = MLHttpClientFactory
.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false, false);
assertNotNull(client);
client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false, true);
assertNotNull(client);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,15 @@ protected SdkAsyncHttpClient getHttpClient() {
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
boolean skipSslVerification = false;
if (connector.getParameters() != null && connector.getParameters().containsKey(SKIP_SSL_VERIFICATION)) {
skipSslVerification = Boolean.parseBoolean(connector.getParameters().get(SKIP_SSL_VERIFICATION));
}
this.httpClientRef
.compareAndSet(
null,
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
MLHttpClientFactory
.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled, skipSslVerification)
);
}
return httpClientRef.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,15 @@ protected SdkAsyncHttpClient getHttpClient() {
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
boolean skipSslVerification = false;
if (connector.getParameters() != null && connector.getParameters().containsKey(SKIP_SSL_VERIFICATION)) {
skipSslVerification = Boolean.parseBoolean(connector.getParameters().get(SKIP_SSL_VERIFICATION));
}
this.httpClientRef
.compareAndSet(
null,
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
MLHttpClientFactory
.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled, skipSslVerification)
);
}
return httpClientRef.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
public interface RemoteConnectorExecutor {

public String RETRY_EXECUTOR = "opensearch_ml_predict_remote";
String SKIP_SSL_VERIFICATION = "skip_ssl_verification";

default void executeAction(String action, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
executeAction(action, mlInput, actionListener, null);
Expand Down
Loading