Add mock for getChecksum

pull/11451/head
Peixuan Ding 2021-05-22 00:32:19 -04:00
parent 5cf72b6eab
commit b9292bde0d
2 changed files with 55 additions and 17 deletions

View File

@ -32,6 +32,8 @@ func TestDownload(t *testing.T) {
t.Run("PreloadDownloadPreventsMultipleDownload", testPreloadDownloadPreventsMultipleDownload)
t.Run("ImageToCache", testImageToCache)
t.Run("ImageToDaemon", testImageToDaemon)
t.Run("PreloadNotExists", testPreloadNotExists)
t.Run("PreloadChecksumMismatch", testPreloadChecksumMismatch)
}
// Returns a mock function that sleeps before incrementing `downloadsCounter` and creates the requested file.
@ -85,8 +87,8 @@ func testPreloadDownloadPreventsMultipleDownload(t *testing.T) {
return nil, nil
}
checkPreloadExists = func(k8sVersion, containerRuntime string, forcePreload ...bool) bool { return true }
getChecksum = func(k8sVersion, containerRuntime string) (string, error) { return "check", nil }
ensureChecksumValid = func(k8sVersion, containerRuntime, path string) error { return nil }
getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) { return []byte("check"), nil }
ensureChecksumValid = func(k8sVersion, containerRuntime, path string, checksum []byte) error { return nil }
var group sync.WaitGroup
group.Add(2)
@ -107,6 +109,45 @@ func testPreloadDownloadPreventsMultipleDownload(t *testing.T) {
}
}
func testPreloadNotExists(t *testing.T) {
downloadNum := 0
DownloadMock = mockSleepDownload(&downloadNum)
checkCache = func(file string) (fs.FileInfo, error) { return nil, fmt.Errorf("cache not found") }
checkPreloadExists = func(k8sVersion, containerRuntime string, forcePreload ...bool) bool { return false }
getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) { return []byte("check"), nil }
ensureChecksumValid = func(k8sVersion, containerRuntime, path string, checksum []byte) error { return nil }
err := Preload(constants.DefaultKubernetesVersion, constants.DefaultContainerRuntime)
if err != nil {
t.Errorf("Expected no error when preload exists")
}
if downloadNum != 0 {
t.Errorf("Expected no download attempt but got %v!", downloadNum)
}
}
func testPreloadChecksumMismatch(t *testing.T) {
downloadNum := 0
DownloadMock = mockSleepDownload(&downloadNum)
checkCache = func(file string) (fs.FileInfo, error) { return nil, fmt.Errorf("cache not found") }
checkPreloadExists = func(k8sVersion, containerRuntime string, forcePreload ...bool) bool { return true }
getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) { return []byte("check"), nil }
ensureChecksumValid = func(k8sVersion, containerRuntime, path string, checksum []byte) error {
return fmt.Errorf("checksum mismatch")
}
err := Preload(constants.DefaultKubernetesVersion, constants.DefaultContainerRuntime)
expectedErrMsg := "checksum mismatch"
if err == nil {
t.Errorf("Expected error when checksum mismatches")
} else if err.Error() != expectedErrMsg {
t.Errorf("Expected error to be %s, got %s", expectedErrMsg, err.Error())
}
}
func testImageToCache(t *testing.T) {
downloadNum := 0
DownloadMock = mockSleepDownload(&downloadNum)

View File

@ -163,15 +163,16 @@ func Preload(k8sVersion, containerRuntime string) error {
return errors.Wrap(err, "tempfile")
}
targetPath = tmp.Name()
} else if checksum != "" {
url += "?checksum=" + checksum
} else if checksum != nil {
// add URL parameter for go-getter to automatically verify the checksum
url += fmt.Sprintf("?checksum=md5:%s", hex.EncodeToString(checksum))
}
if err := download(url, targetPath); err != nil {
return errors.Wrapf(err, "download failed: %s", url)
}
if err := ensureChecksumValid(k8sVersion, containerRuntime, targetPath); err != nil {
if err := ensureChecksumValid(k8sVersion, containerRuntime, targetPath, checksum); err != nil {
return err
}
@ -199,23 +200,19 @@ func getStorageAttrs(name string) (*storage.ObjectAttrs, error) {
return attrs, nil
}
var getChecksum = func(k8sVersion, containerRuntime string) (string, error) {
// getChecksum returns the MD5 checksum of the preload tarball
var getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) {
klog.Infof("getting checksum for %s ...", TarballName(k8sVersion, containerRuntime))
attrs, err := getStorageAttrs(TarballName(k8sVersion, containerRuntime))
if err != nil {
return "", err
return nil, err
}
md5 := hex.EncodeToString(attrs.MD5)
return fmt.Sprintf("md5:%s", md5), nil
return attrs.MD5, nil
}
func saveChecksumFile(k8sVersion, containerRuntime string) error {
// saveChecksumFile saves the checksum to a local file for later verification
func saveChecksumFile(k8sVersion, containerRuntime string, checksum []byte) error {
klog.Infof("saving checksum for %s ...", TarballName(k8sVersion, containerRuntime))
attrs, err := getStorageAttrs(TarballName(k8sVersion, containerRuntime))
if err != nil {
return err
}
checksum := attrs.MD5
return ioutil.WriteFile(PreloadChecksumPath(k8sVersion, containerRuntime), checksum, 0o644)
}
@ -243,8 +240,8 @@ func verifyChecksum(k8sVersion, containerRuntime, path string) error {
}
// ensureChecksumValid saves and verifies local binary checksum matches remote binary checksum
var ensureChecksumValid = func(k8sVersion, containerRuntime, targetPath string) error {
if err := saveChecksumFile(k8sVersion, containerRuntime); err != nil {
var ensureChecksumValid = func(k8sVersion, containerRuntime, targetPath string, checksum []byte) error {
if err := saveChecksumFile(k8sVersion, containerRuntime, checksum); err != nil {
return errors.Wrap(err, "saving checksum file")
}