Add mock for getChecksum
parent
5cf72b6eab
commit
b9292bde0d
|
@ -32,6 +32,8 @@ func TestDownload(t *testing.T) {
|
|||
t.Run("PreloadDownloadPreventsMultipleDownload", testPreloadDownloadPreventsMultipleDownload)
|
||||
t.Run("ImageToCache", testImageToCache)
|
||||
t.Run("ImageToDaemon", testImageToDaemon)
|
||||
t.Run("PreloadNotExists", testPreloadNotExists)
|
||||
t.Run("PreloadChecksumMismatch", testPreloadChecksumMismatch)
|
||||
}
|
||||
|
||||
// Returns a mock function that sleeps before incrementing `downloadsCounter` and creates the requested file.
|
||||
|
@ -85,8 +87,8 @@ func testPreloadDownloadPreventsMultipleDownload(t *testing.T) {
|
|||
return nil, nil
|
||||
}
|
||||
checkPreloadExists = func(k8sVersion, containerRuntime string, forcePreload ...bool) bool { return true }
|
||||
getChecksum = func(k8sVersion, containerRuntime string) (string, error) { return "check", nil }
|
||||
ensureChecksumValid = func(k8sVersion, containerRuntime, path string) error { return nil }
|
||||
getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) { return []byte("check"), nil }
|
||||
ensureChecksumValid = func(k8sVersion, containerRuntime, path string, checksum []byte) error { return nil }
|
||||
|
||||
var group sync.WaitGroup
|
||||
group.Add(2)
|
||||
|
@ -107,6 +109,45 @@ func testPreloadDownloadPreventsMultipleDownload(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testPreloadNotExists(t *testing.T) {
|
||||
downloadNum := 0
|
||||
DownloadMock = mockSleepDownload(&downloadNum)
|
||||
|
||||
checkCache = func(file string) (fs.FileInfo, error) { return nil, fmt.Errorf("cache not found") }
|
||||
checkPreloadExists = func(k8sVersion, containerRuntime string, forcePreload ...bool) bool { return false }
|
||||
getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) { return []byte("check"), nil }
|
||||
ensureChecksumValid = func(k8sVersion, containerRuntime, path string, checksum []byte) error { return nil }
|
||||
|
||||
err := Preload(constants.DefaultKubernetesVersion, constants.DefaultContainerRuntime)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error when preload exists")
|
||||
}
|
||||
|
||||
if downloadNum != 0 {
|
||||
t.Errorf("Expected no download attempt but got %v!", downloadNum)
|
||||
}
|
||||
}
|
||||
|
||||
func testPreloadChecksumMismatch(t *testing.T) {
|
||||
downloadNum := 0
|
||||
DownloadMock = mockSleepDownload(&downloadNum)
|
||||
|
||||
checkCache = func(file string) (fs.FileInfo, error) { return nil, fmt.Errorf("cache not found") }
|
||||
checkPreloadExists = func(k8sVersion, containerRuntime string, forcePreload ...bool) bool { return true }
|
||||
getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) { return []byte("check"), nil }
|
||||
ensureChecksumValid = func(k8sVersion, containerRuntime, path string, checksum []byte) error {
|
||||
return fmt.Errorf("checksum mismatch")
|
||||
}
|
||||
|
||||
err := Preload(constants.DefaultKubernetesVersion, constants.DefaultContainerRuntime)
|
||||
expectedErrMsg := "checksum mismatch"
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when checksum mismatches")
|
||||
} else if err.Error() != expectedErrMsg {
|
||||
t.Errorf("Expected error to be %s, got %s", expectedErrMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func testImageToCache(t *testing.T) {
|
||||
downloadNum := 0
|
||||
DownloadMock = mockSleepDownload(&downloadNum)
|
||||
|
|
|
@ -163,15 +163,16 @@ func Preload(k8sVersion, containerRuntime string) error {
|
|||
return errors.Wrap(err, "tempfile")
|
||||
}
|
||||
targetPath = tmp.Name()
|
||||
} else if checksum != "" {
|
||||
url += "?checksum=" + checksum
|
||||
} else if checksum != nil {
|
||||
// add URL parameter for go-getter to automatically verify the checksum
|
||||
url += fmt.Sprintf("?checksum=md5:%s", hex.EncodeToString(checksum))
|
||||
}
|
||||
|
||||
if err := download(url, targetPath); err != nil {
|
||||
return errors.Wrapf(err, "download failed: %s", url)
|
||||
}
|
||||
|
||||
if err := ensureChecksumValid(k8sVersion, containerRuntime, targetPath); err != nil {
|
||||
if err := ensureChecksumValid(k8sVersion, containerRuntime, targetPath, checksum); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -199,23 +200,19 @@ func getStorageAttrs(name string) (*storage.ObjectAttrs, error) {
|
|||
return attrs, nil
|
||||
}
|
||||
|
||||
var getChecksum = func(k8sVersion, containerRuntime string) (string, error) {
|
||||
// getChecksum returns the MD5 checksum of the preload tarball
|
||||
var getChecksum = func(k8sVersion, containerRuntime string) ([]byte, error) {
|
||||
klog.Infof("getting checksum for %s ...", TarballName(k8sVersion, containerRuntime))
|
||||
attrs, err := getStorageAttrs(TarballName(k8sVersion, containerRuntime))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
md5 := hex.EncodeToString(attrs.MD5)
|
||||
return fmt.Sprintf("md5:%s", md5), nil
|
||||
return attrs.MD5, nil
|
||||
}
|
||||
|
||||
func saveChecksumFile(k8sVersion, containerRuntime string) error {
|
||||
// saveChecksumFile saves the checksum to a local file for later verification
|
||||
func saveChecksumFile(k8sVersion, containerRuntime string, checksum []byte) error {
|
||||
klog.Infof("saving checksum for %s ...", TarballName(k8sVersion, containerRuntime))
|
||||
attrs, err := getStorageAttrs(TarballName(k8sVersion, containerRuntime))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
checksum := attrs.MD5
|
||||
return ioutil.WriteFile(PreloadChecksumPath(k8sVersion, containerRuntime), checksum, 0o644)
|
||||
}
|
||||
|
||||
|
@ -243,8 +240,8 @@ func verifyChecksum(k8sVersion, containerRuntime, path string) error {
|
|||
}
|
||||
|
||||
// ensureChecksumValid saves and verifies local binary checksum matches remote binary checksum
|
||||
var ensureChecksumValid = func(k8sVersion, containerRuntime, targetPath string) error {
|
||||
if err := saveChecksumFile(k8sVersion, containerRuntime); err != nil {
|
||||
var ensureChecksumValid = func(k8sVersion, containerRuntime, targetPath string, checksum []byte) error {
|
||||
if err := saveChecksumFile(k8sVersion, containerRuntime, checksum); err != nil {
|
||||
return errors.Wrap(err, "saving checksum file")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue