pod restore action: check initContainers exist before processing

Signed-off-by: Steve Kriss <steve@heptio.com>
pull/927/head
Steve Kriss 2018-10-11 13:44:02 -06:00
parent d7dfffa373
commit 38e86ceff5
2 changed files with 119 additions and 113 deletions

View File

@ -51,97 +51,82 @@ func (a *podAction) Execute(obj runtime.Unstructured, restore *api.Restore) (run
a.logger.Debug("deleting spec.NodeName")
delete(spec, "nodeName")
// if there are no volumes, then there can't be any volume mounts, so we're done.
if !collections.Exists(spec, "volumes") {
return obj, nil, nil
}
serviceAccountName, err := collections.GetString(spec, "serviceAccountName")
if err != nil {
return nil, nil, err
}
prefix := serviceAccountName + "-token-"
newVolumes := make([]interface{}, 0)
// remove the service account token from volumes
a.logger.Debug("iterating over volumes")
err = collections.ForEach(spec, "volumes", func(volume map[string]interface{}) error {
name, err := collections.GetString(volume, "name")
if err != nil {
return err
}
a.logger.WithField("volumeName", name).Debug("Checking volume")
if strings.HasPrefix(name, serviceAccountName+"-token-") {
a.logger.WithField("volumeName", name).Debug("Excluding volume")
} else {
a.logger.WithField("volumeName", name).Debug("Preserving volume")
newVolumes = append(newVolumes, volume)
}
return nil
})
if err != nil {
if err := removeItemsWithNamePrefix(spec, "volumes", prefix, a.logger); err != nil {
return nil, nil, err
}
a.logger.Debug("Setting spec.volumes")
spec["volumes"] = newVolumes
// remove the service account token volume mount from all containers
a.logger.Debug("iterating over containers")
err = collections.ForEach(spec, "containers", func(container map[string]interface{}) error {
var newVolumeMounts []interface{}
err := collections.ForEach(container, "volumeMounts", func(volumeMount map[string]interface{}) error {
name, err := collections.GetString(volumeMount, "name")
if err != nil {
return err
}
a.logger.WithField("volumeMount", name).Debug("Checking volumeMount")
if strings.HasPrefix(name, serviceAccountName+"-token-") {
a.logger.WithField("volumeMount", name).Debug("Excluding volumeMount")
} else {
a.logger.WithField("volumeMount", name).Debug("Preserving volumeMount")
newVolumeMounts = append(newVolumeMounts, volumeMount)
}
return nil
})
if err != nil {
return err
}
container["volumeMounts"] = newVolumeMounts
return nil
})
if err != nil {
if err := removeVolumeMounts(spec, "containers", prefix, a.logger); err != nil {
return nil, nil, err
}
if !collections.Exists(spec, "initContainers") {
return obj, nil, nil
}
// remove the service account token volume mount from all init containers
a.logger.Debug("iterating over init containers")
err = collections.ForEach(spec, "initContainers", func(container map[string]interface{}) error {
var newVolumeMounts []interface{}
err := collections.ForEach(container, "volumeMounts", func(volumeMount map[string]interface{}) error {
name, err := collections.GetString(volumeMount, "name")
if err != nil {
return err
}
a.logger.WithField("volumeMount", name).Debug("Checking volumeMount")
if strings.HasPrefix(name, serviceAccountName+"-token-") {
a.logger.WithField("volumeMount", name).Debug("Excluding volumeMount")
} else {
a.logger.WithField("volumeMount", name).Debug("Preserving volumeMount")
newVolumeMounts = append(newVolumeMounts, volumeMount)
}
return nil
})
if err != nil {
return err
}
container["volumeMounts"] = newVolumeMounts
return nil
})
if err != nil {
if err := removeVolumeMounts(spec, "initContainers", prefix, a.logger); err != nil {
return nil, nil, err
}
return obj, nil, nil
}
// removeItemsWithNamePrefix iterates through the collection stored at 'key' in 'unstructuredObj'
// and removes any item that has a name that starts with 'prefix'.
func removeItemsWithNamePrefix(unstructuredObj map[string]interface{}, key, prefix string, log logrus.FieldLogger) error {
var preservedItems []interface{}
if err := collections.ForEach(unstructuredObj, key, func(item map[string]interface{}) error {
name, err := collections.GetString(item, "name")
if err != nil {
return err
}
singularKey := strings.TrimSuffix(key, "s")
log := log.WithField(singularKey, name)
log.Debug("Checking " + singularKey)
switch {
case strings.HasPrefix(name, prefix):
log.Debug("Excluding ", singularKey)
default:
log.Debug("Preserving ", singularKey)
preservedItems = append(preservedItems, item)
}
return nil
}); err != nil {
return err
}
unstructuredObj[key] = preservedItems
return nil
}
// removeVolumeMounts iterates through a slice of containers stored at 'containersKey' in
// 'podSpec' and removes any volume mounts with a name starting with 'prefix'.
func removeVolumeMounts(podSpec map[string]interface{}, containersKey, prefix string, log logrus.FieldLogger) error {
return collections.ForEach(podSpec, containersKey, func(container map[string]interface{}) error {
if !collections.Exists(container, "volumeMounts") {
return nil
}
return removeItemsWithNamePrefix(container, "volumeMounts", prefix, log)
})
}

View File

@ -41,50 +41,46 @@ func TestPodActionExecute(t *testing.T) {
name: "nodeName (only) should be deleted from spec",
obj: NewTestUnstructured().WithName("pod-1").WithSpec("nodeName", "foo").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{}).
WithSpecField("containers", []interface{}{}).
WithSpecField("initContainers", []interface{}{}).
Unstructured,
expectedErr: false,
expectedRes: NewTestUnstructured().WithName("pod-1").WithSpec("foo").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{}).
WithSpecField("containers", []interface{}{}).
WithSpecField("initContainers", []interface{}{}).
Unstructured,
},
{
name: "volumes matching prefix ServiceAccount-token- should be deleted",
name: "volumes matching prefix <service account name>-token- should be deleted",
obj: NewTestUnstructured().WithName("pod-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("initContainers", []interface{}{}).
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
map[string]interface{}{"name": "foo-token-foo"},
}).WithSpecField("containers", []interface{}{}).Unstructured,
}).
WithSpecField("containers", []interface{}{}).
Unstructured,
expectedErr: false,
expectedRes: NewTestUnstructured().WithName("pod-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("initContainers", []interface{}{}).
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
}).WithSpecField("containers", []interface{}{}).Unstructured,
}).
WithSpecField("containers", []interface{}{}).
Unstructured,
},
{
name: "container volumeMounts matching prefix ServiceAccount-token- should be deleted",
name: "container volumeMounts matching prefix <service account name>-token- should be deleted",
obj: NewTestUnstructured().WithName("svc-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{}).
WithSpecField("initContainers", []interface{}{}).
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
map[string]interface{}{"name": "foo-token-foo"},
}).
WithSpecField("containers", []interface{}{
map[string]interface{}{
"volumeMounts": []interface{}{
map[string]interface{}{
"name": "foo",
},
map[string]interface{}{
"name": "foo-token-foo",
},
map[string]interface{}{"name": "foo"},
map[string]interface{}{"name": "foo-token-foo"},
},
},
}).
@ -92,34 +88,32 @@ func TestPodActionExecute(t *testing.T) {
expectedErr: false,
expectedRes: NewTestUnstructured().WithName("svc-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{}).
WithSpecField("initContainers", []interface{}{}).
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
}).
WithSpecField("containers", []interface{}{
map[string]interface{}{
"volumeMounts": []interface{}{
map[string]interface{}{
"name": "foo",
},
map[string]interface{}{"name": "foo"},
},
},
}).
Unstructured,
},
{
name: "initContainer volumeMounts matching prefix ServiceAccount-token- should be deleted",
name: "initContainer volumeMounts matching prefix <service account name>-token- should be deleted",
obj: NewTestUnstructured().WithName("svc-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{}).
WithSpecField("containers", []interface{}{}).
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
map[string]interface{}{"name": "foo-token-foo"},
}).
WithSpecField("initContainers", []interface{}{
map[string]interface{}{
"volumeMounts": []interface{}{
map[string]interface{}{
"name": "foo",
},
map[string]interface{}{
"name": "foo-token-foo",
},
map[string]interface{}{"name": "foo"},
map[string]interface{}{"name": "foo-token-foo"},
},
},
}).
@ -127,30 +121,57 @@ func TestPodActionExecute(t *testing.T) {
expectedErr: false,
expectedRes: NewTestUnstructured().WithName("svc-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{}).
WithSpecField("containers", []interface{}{}).
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
}).
WithSpecField("initContainers", []interface{}{
map[string]interface{}{
"volumeMounts": []interface{}{
map[string]interface{}{
"name": "foo",
},
map[string]interface{}{"name": "foo"},
},
},
}).
Unstructured,
},
{
name: "containers and initContainers with no volume mounts should not error",
obj: NewTestUnstructured().WithName("pod-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
map[string]interface{}{"name": "foo-token-foo"},
}).
WithSpecField("containers", []interface{}{}).
WithSpecField("initContainers", []interface{}{}).
Unstructured,
expectedErr: false,
expectedRes: NewTestUnstructured().WithName("pod-1").
WithSpec("serviceAccountName", "foo").
WithSpecField("volumes", []interface{}{
map[string]interface{}{"name": "foo"},
}).
WithSpecField("containers", []interface{}{}).
WithSpecField("initContainers", []interface{}{}).
Unstructured,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
action := NewPodAction(arktest.NewLogger())
res, _, err := action.Execute(test.obj, nil)
res, warning, err := action.Execute(test.obj, nil)
if assert.Equal(t, test.expectedErr, err != nil) {
assert.Equal(t, test.expectedRes, res)
assert.Nil(t, warning)
if test.expectedErr {
assert.NotNil(t, err, "expected an error")
} else {
assert.Nil(t, err, "expected no error, got %v", err)
}
assert.Equal(t, test.expectedRes, res)
})
}
}