Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions pkg/cloudprovider/aws/aws.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package aws

import (
"errors"
"fmt"
"strings"
"time"
Expand All @@ -13,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
)
Expand Down Expand Up @@ -40,8 +40,15 @@ func instanceToProviderID(instance *autoscaling.Instance) string {
return fmt.Sprintf("aws:///%s/%s", *instance.AvailabilityZone, *instance.InstanceId)
}

func providerIDToInstanceID(providerID string) string {
return strings.Split(providerID, "/")[4]
func providerIDToInstanceID(providerID string) (string, error) {
if providerID == "" {
return "", fmt.Errorf("empty providerID, it may be set later by cloud controller")
}
parts := strings.Split(providerID, "/")
if len(parts) < 5 {
return "", fmt.Errorf("malformed providerID %s: expected at least 4 slashes", providerID)
}
return parts[4], nil
}

// CloudProvider providers an aws cloud provider implementation
Expand Down Expand Up @@ -136,8 +143,10 @@ type Instance struct {
func (c *CloudProvider) GetInstance(node *v1.Node) (cloudprovider.Instance, error) {
var instance *Instance

id := providerIDToInstanceID(node.Spec.ProviderID)

id, err := providerIDToInstanceID(node.Spec.ProviderID)
if err != nil {
return instance, errors.Wrap(err, "failed to get instance ID from provider ID")
}
input := &ec2.DescribeInstancesInput{
InstanceIds: []*string{&id},
}
Expand Down
26 changes: 19 additions & 7 deletions pkg/cloudprovider/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,21 @@ func TestInstanceToProviderId(t *testing.T) {
}

func TestProviderIdToInstanceId(t *testing.T) {
assert.Equal(t, "abc123", providerIDToInstanceID("aws:///us-east-1b/abc123"))
id, err := providerIDToInstanceID("aws:///us-east-1b/abc123")
assert.Nil(t, err)
assert.Equal(t, "abc123", id)
}

func TestProviderIdToInstanceIdEmpty(t *testing.T) {
_, err := providerIDToInstanceID("")
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "empty providerID, it may be set later by cloud controller")
}

func TestProviderIdToInstanceIdMalformed(t *testing.T) {
_, err := providerIDToInstanceID("fake://provider/id")
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "malformed providerID fake://provider/id: expected at least 4 slashes")
}

func newMockCloudProvider(nodeGroups []string, service *test.MockAutoscalingService, ec2Service *test.MockEc2Service) (*CloudProvider, error) {
Expand Down Expand Up @@ -195,10 +209,9 @@ func TestCreateTemplateOverrides_NoASG(t *testing.T) {
)
mockNodeGroup.provider = awsCloudProvider

_, error := createTemplateOverrides(mockNodeGroup)
_, err := createTemplateOverrides(mockNodeGroup)
errorMessage := "failed to get an ASG from DescribeAutoscalingGroups response"
e := errors.New(errorMessage)
assert.Equalf(t, e, error, "Expected error with message '%v'", errorMessage)
assert.EqualError(t, err, errorMessage)
}

func TestCreateTemplateOverrides_NoSubnetIDs(t *testing.T) {
Expand All @@ -219,10 +232,9 @@ func TestCreateTemplateOverrides_NoSubnetIDs(t *testing.T) {
)
mockNodeGroup.provider = awsCloudProvider

_, error := createTemplateOverrides(mockNodeGroup)
_, err := createTemplateOverrides(mockNodeGroup)
errorMessage := "failed to get any subnetIDs from DescribeAutoscalingGroups response"
e := errors.New(errorMessage)
assert.Equalf(t, e, error, "Expected error with message '%v'", errorMessage)
assert.EqualError(t, err, errorMessage)
}

func TestCreateTemplateOverrides_Success(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions pkg/cloudprovider/aws/cloud_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,15 @@ func TestCloudProvider_GetInstance(t *testing.T) {
})
}
}

func TestCloudProvider_GetInstance_No_Provider_ID(t *testing.T) {
nodeGroups := []string{"1"}
node := &v1.Node{
Spec: v1.NodeSpec{},
}
ec2Service := &test.MockEc2Service{}
awsCloudProvider, err := newMockCloudProvider(nodeGroups, nil, ec2Service)
assert.Nil(t, err)
_, err = awsCloudProvider.GetInstance(node)
assert.NotNil(t, err)
}