From acb3babb87dad84bb46affb44fcba54e1ea36bd4 Mon Sep 17 00:00:00 2001 From: Steve Kriss Date: Wed, 13 Sep 2017 15:53:35 -0700 Subject: [PATCH] when restoring azureDisk from snapshot, update the diskURI with the new diskName Signed-off-by: Steve Kriss --- pkg/util/kube/utils.go | 18 +++++- pkg/util/kube/utils_test.go | 108 ++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 pkg/util/kube/utils_test.go diff --git a/pkg/util/kube/utils.go b/pkg/util/kube/utils.go index a9b5b0b58..af8bbb699 100644 --- a/pkg/util/kube/utils.go +++ b/pkg/util/kube/utils.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "regexp" + "strings" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -103,15 +104,28 @@ func GetPVSource(spec map[string]interface{}) (string, map[string]interface{}) { } // SetVolumeID looks for a supported PV source within the provided PV spec data. -// If sets the appropriate ID field within the source if found, and returns an +// If sets the appropriate ID field(s) within the source if found, and returns an // error if a supported PV source is not found. func SetVolumeID(spec map[string]interface{}, volumeID string) error { sourceType, source := GetPVSource(spec) - if sourceType == "" { return errors.New("persistent volume source is not compatible") } + // for azureDisk, we need to do a find-replace within the diskURI (if it exists) + // to switch the old disk name with the new. + if sourceType == "azureDisk" { + uri, err := collections.GetString(source, "diskURI") + if err == nil { + priorVolumeID, err := collections.GetString(source, supportedVolumeTypes["azureDisk"]) + if err != nil { + return err + } + + source["diskURI"] = strings.Replace(uri, priorVolumeID, volumeID, -1) + } + } + source[supportedVolumeTypes[sourceType]] = volumeID return nil diff --git a/pkg/util/kube/utils_test.go b/pkg/util/kube/utils_test.go new file mode 100644 index 000000000..3a06c6732 --- /dev/null +++ b/pkg/util/kube/utils_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2017 Heptio Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kube + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/heptio/ark/pkg/util/collections" +) + +func TestSetVolumeID(t *testing.T) { + tests := []struct { + name string + spec map[string]interface{} + volumeID string + expectedErr error + specFieldExpectations map[string]string + }{ + { + name: "awsElasticBlockStore normal case", + spec: map[string]interface{}{ + "awsElasticBlockStore": map[string]interface{}{ + "volumeID": "vol-old", + }, + }, + volumeID: "vol-new", + expectedErr: nil, + }, + { + name: "gcePersistentDisk normal case", + spec: map[string]interface{}{ + "gcePersistentDisk": map[string]interface{}{ + "pdName": "old-pd", + }, + }, + volumeID: "new-pd", + expectedErr: nil, + }, + { + name: "azureDisk normal case", + spec: map[string]interface{}{ + "azureDisk": map[string]interface{}{ + "diskName": "old-disk", + "diskURI": "some-nonsense/old-disk", + }, + }, + volumeID: "new-disk", + expectedErr: nil, + specFieldExpectations: map[string]string{ + "azureDisk.diskURI": "some-nonsense/new-disk", + }, + }, + { + name: "azureDisk with no diskURI", + spec: map[string]interface{}{ + "azureDisk": map[string]interface{}{ + "diskName": "old-disk", + }, + }, + volumeID: "new-disk", + expectedErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := SetVolumeID(test.spec, test.volumeID) + + require.Equal(t, test.expectedErr, err) + + if test.expectedErr != nil { + return + } + + pv := map[string]interface{}{ + "spec": test.spec, + } + + volumeID, err := GetVolumeID(pv) + require.Nil(t, err) + + assert.Equal(t, test.volumeID, volumeID) + + for path, expected := range test.specFieldExpectations { + actual, err := collections.GetString(test.spec, path) + assert.Nil(t, err) + assert.Equal(t, expected, actual) + } + }) + } +}