diff --git a/pkg/fs/fs_test.go b/pkg/fs/fs_test.go new file mode 100644 index 0000000000..d51c97a0b7 --- /dev/null +++ b/pkg/fs/fs_test.go @@ -0,0 +1,130 @@ +package fs_test + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/influxdata/influxdb/pkg/fs" +) + +func TestRenameFileWithReplacement(t *testing.T) { + t.Run("exists", func(t *testing.T) { + oldpath := MustCreateTempFile() + newpath := MustCreateTempFile() + defer MustRemoveAll(oldpath) + defer MustRemoveAll(newpath) + + oldContents := MustReadAllFile(oldpath) + newContents := MustReadAllFile(newpath) + + if got, exp := oldContents, oldpath; got != exp { + t.Fatalf("got contents %q, expected %q", got, exp) + } else if got, exp := newContents, newpath; got != exp { + t.Fatalf("got contents %q, expected %q", got, exp) + } + + if err := fs.RenameFileWithReplacement(oldpath, newpath); err != nil { + t.Fatalf("ReplaceFileIfExists returned an error: %s", err) + } + + if err := fs.SyncDir(filepath.Dir(oldpath)); err != nil { + panic(err) + } + + // Contents of newpath will now be equivalent to oldpath' contents. + newContents = MustReadAllFile(newpath) + if newContents != oldContents { + t.Fatalf("contents for files differ: %q versus %q", newContents, oldContents) + } + + // oldpath will be removed. + if MustFileExists(oldpath) { + t.Fatalf("file %q still exists, but it shouldn't", oldpath) + } + }) + + t.Run("not exists", func(t *testing.T) { + oldpath := MustCreateTempFile() + defer MustRemoveAll(oldpath) + + oldContents := MustReadAllFile(oldpath) + if got, exp := oldContents, oldpath; got != exp { + t.Fatalf("got contents %q, expected %q", got, exp) + } + + root := filepath.Dir(oldpath) + newpath := filepath.Join(root, "foo") + if err := fs.RenameFileWithReplacement(oldpath, newpath); err != nil { + t.Fatalf("ReplaceFileIfExists returned an error: %s", err) + } + + if err := fs.SyncDir(filepath.Dir(oldpath)); err != nil { + panic(err) + } + + // Contents of newpath will now be equivalent to oldpath's contents. + newContents := MustReadAllFile(newpath) + if newContents != oldContents { + t.Fatalf("contents for files differ: %q versus %q", newContents, oldContents) + } + + // oldpath will be removed. + if MustFileExists(oldpath) { + t.Fatalf("file %q still exists, but it shouldn't", oldpath) + } + }) +} + +// MustCreateTempFile creates a temporary file returning the path to the file. +// +// MustCreateTempFile writes the absolute path to the file into the file itself. +// It panics if there is an error. +func MustCreateTempFile() string { + f, err := ioutil.TempFile("", "fs-test") + if err != nil { + panic(fmt.Sprintf("failed to create temp file: %v", err)) + } + + name := f.Name() + f.WriteString(name) + if err := f.Close(); err != nil { + panic(err) + } + return name +} + +func MustRemoveAll(path string) { + if err := os.RemoveAll(path); err != nil { + panic(err) + } +} + +// MustFileExists determines if a file exists, panicking if any error +// (other than one associated with the file not existing) is returned. +func MustFileExists(path string) bool { + _, err := os.Stat(path) + if err == nil { + return true + } else if os.IsNotExist(err) { + return false + } + panic(err) +} + +// MustReadAllFile reads the contents of path, panicking if there is an error. +func MustReadAllFile(path string) string { + fd, err := os.Open(path) + if err != nil { + panic(err) + } + defer fd.Close() + + data, err := ioutil.ReadAll(fd) + if err != nil { + panic(err) + } + return string(data) +} diff --git a/pkg/fs/fs_unix.go b/pkg/fs/fs_unix.go index 7386516d59..38a969a50b 100644 --- a/pkg/fs/fs_unix.go +++ b/pkg/fs/fs_unix.go @@ -3,10 +3,25 @@ package fs import ( + "fmt" "os" "syscall" ) +// A FileExistsError is returned when an operation cannot be completed due to a +// file already existing. +type FileExistsError struct { + path string +} + +func newFileExistsError(path string) FileExistsError { + return FileExistsError{path: path} +} + +func (e FileExistsError) Error() string { + return fmt.Sprintf("operation not allowed, file %q exists", e.path) +} + // SyncDir flushes any file renames to the filesystem. func SyncDir(dirName string) error { // fsync the dir to flush the rename @@ -39,3 +54,14 @@ func SyncDir(dirName string) error { func RenameFileWithReplacement(oldpath, newpath string) error { return os.Rename(oldpath, newpath) } + +// RenameFile renames oldpath to newpath, returning an error if newpath already +// exists. If this function returns successfully, the contents of newpath will +// be identical to oldpath, and oldpath will be removed. +func RenameFile(oldpath, newpath string) error { + if _, err := os.Stat(newpath); err == nil { + return newFileExistsError(newpath) + } + + return os.Rename(oldpath, newpath) +} diff --git a/pkg/fs/fs_windows.go b/pkg/fs/fs_windows.go index 7b9a121659..8d92516d5a 100644 --- a/pkg/fs/fs_windows.go +++ b/pkg/fs/fs_windows.go @@ -21,3 +21,16 @@ func RenameFileWithReplacement(oldpath, newpath string) error { return os.Rename(oldpath, newpath) } + +// RenameFile renames oldpath to newpath, returning an error if newpath already +// exists. If this function returns successfully, the contents of newpath will +// be identical to oldpath, and oldpath will be removed. +func RenameFile(oldpath, newpath string) error { + if _, err := os.Stat(newpath); err == nil { + // os.Rename on Windows will return an error if the file exists, but it's + // preferable to keep the errors the same across platforms. + return newFileExistsError(newpath) + } + + return os.Rename(oldpath, newpath) +}