Refactoring bolt usage of IDs (tasks)

It should be no more necessary to perform operations directly on ID bytes (pad vs unpad).
pull/10616/head
Leonardo Di Donato 2018-07-31 01:54:43 +02:00 committed by Leonardo Di Donato
parent b67d3123e2
commit 02a80a0665
1 changed files with 143 additions and 74 deletions

View File

@ -22,7 +22,6 @@ package bolt
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -97,61 +96,74 @@ func New(db *bolt.DB, rootBucket string) (*Store, error) {
} }
// CreateTask creates a task in the boltdb task store. // CreateTask creates a task in the boltdb task store.
func (s *Store) CreateTask(ctx context.Context, org, user platform.ID, script string) (platform.ID, error) { func (s *Store) CreateTask(ctx context.Context, org, user platform.ID, script string) (*platform.ID, error) {
o, err := backend.StoreValidator.CreateArgs(org, user, script) o, err := backend.StoreValidator.CreateArgs(org, user, script)
if err != nil { if err != nil {
return nil, err return nil, err
} }
id := make(platform.ID, 8) encOrg, err := org.Encode()
if err != nil {
return nil, err
}
encUser, err := user.Encode()
if err != nil {
return nil, err
}
var id platform.ID
err = s.db.Update(func(tx *bolt.Tx) error { err = s.db.Update(func(tx *bolt.Tx) error {
// get the root bucket // get the root bucket
b := tx.Bucket(s.bucket) b := tx.Bucket(s.bucket)
// Get ID // Get ID
idi, _ := b.NextSequence() // we ignore this err check, because this can't err inside an Update call idi, _ := b.NextSequence() // we ignore this err check, because this can't err inside an Update call
binary.BigEndian.PutUint64(id, idi) id = platform.ID(idi)
encID, err := id.Encode()
if err != nil {
return err
}
// write script // write script
err := b.Bucket(tasksPath).Put(id, []byte(script)) err = b.Bucket(tasksPath).Put(encID, []byte(script))
if err != nil { if err != nil {
return err return err
} }
// name // name
err = b.Bucket(nameByTaskID).Put(id, []byte(o.Name)) err = b.Bucket(nameByTaskID).Put(encID, []byte(o.Name))
if err != nil { if err != nil {
return err return err
} }
// org // org
orgB, err := b.Bucket(orgsPath).CreateBucketIfNotExists([]byte(org)) orgB, err := b.Bucket(orgsPath).CreateBucketIfNotExists(encOrg)
if err != nil { if err != nil {
return err return err
} }
err = orgB.Put(id, nil) err = orgB.Put(encID, nil)
if err != nil { if err != nil {
return err return err
} }
err = b.Bucket(orgByTaskID).Put(id, []byte(org)) err = b.Bucket(orgByTaskID).Put(encID, encOrg)
if err != nil { if err != nil {
return err return err
} }
// user // user
userB, err := b.Bucket(usersPath).CreateBucketIfNotExists([]byte(user)) userB, err := b.Bucket(usersPath).CreateBucketIfNotExists(encUser)
if err != nil { if err != nil {
return err return err
} }
err = userB.Put(id, nil) err = userB.Put(encID, nil)
if err != nil { if err != nil {
return err return err
} }
err = b.Bucket(userByTaskID).Put(id, []byte(user)) err = b.Bucket(userByTaskID).Put(encID, encUser)
if err != nil { if err != nil {
return err return err
} }
@ -166,12 +178,12 @@ func (s *Store) CreateTask(ctx context.Context, org, user platform.ID, script st
return err return err
} }
metaB := b.Bucket(taskMetaPath) metaB := b.Bucket(taskMetaPath)
return metaB.Put(id, stmBytes) return metaB.Put(encID, stmBytes)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return unpadID(id), nil return &id, nil
} }
// ModifyTask changes a task with a new script, it should error if the task does not exist. // ModifyTask changes a task with a new script, it should error if the task does not exist.
@ -180,19 +192,22 @@ func (s *Store) ModifyTask(ctx context.Context, id platform.ID, newScript string
return err return err
} }
paddedID := padID(id) encID, err := id.Encode()
if err != nil {
return err
}
return s.db.Update(func(tx *bolt.Tx) error { return s.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(s.bucket).Bucket(tasksPath) b := tx.Bucket(s.bucket).Bucket(tasksPath)
if v := b.Get(paddedID); v == nil { // this is so we can error if the task doesn't exist if v := b.Get(encID); v == nil { // this is so we can error if the task doesn't exist
return ErrNotFound return ErrNotFound
} }
return b.Put(paddedID, []byte(newScript)) return b.Put(encID, []byte(newScript))
}) })
} }
// ListTasks lists the tasks based on a filter. // ListTasks lists the tasks based on a filter.
func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams) ([]backend.StoreTask, error) { func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams) ([]backend.StoreTask, error) {
if len(params.Org) > 0 && len(params.User) > 0 { if params.Org.Valid() && params.User.Valid() {
return nil, errors.New("ListTasks: org and user filters are mutually exclusive") return nil, errors.New("ListTasks: org and user filters are mutually exclusive")
} }
@ -215,14 +230,22 @@ func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams)
err := s.db.View(func(tx *bolt.Tx) error { err := s.db.View(func(tx *bolt.Tx) error {
var c *bolt.Cursor var c *bolt.Cursor
b := tx.Bucket(s.bucket) b := tx.Bucket(s.bucket)
if len(params.Org) > 0 { if params.Org.Valid() {
orgB := b.Bucket(orgsPath).Bucket(params.Org) encOrgID, err := params.Org.Encode()
if err != nil {
return err
}
orgB := b.Bucket(orgsPath).Bucket(encOrgID)
if orgB == nil { if orgB == nil {
return ErrNotFound return ErrNotFound
} }
c = orgB.Cursor() c = orgB.Cursor()
} else if len(params.User) > 0 { } else if params.User.Valid() {
userB := b.Bucket(usersPath).Bucket(params.User) encUserID, err := params.User.Encode()
if err != nil {
return err
}
userB := b.Bucket(usersPath).Bucket(encUserID)
if userB == nil { if userB == nil {
return ErrNotFound return ErrNotFound
} }
@ -230,15 +253,27 @@ func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams)
} else { } else {
c = b.Bucket(tasksPath).Cursor() c = b.Bucket(tasksPath).Cursor()
} }
if len(params.After) > 0 { if params.After.Valid() {
c.Seek(padID(params.After)) encAfterID, err := params.After.Encode()
if err != nil {
return err
}
c.Seek(encAfterID)
for k, _ := c.Next(); k != nil && len(taskIDs) < lim; k, _ = c.Next() { for k, _ := c.Next(); k != nil && len(taskIDs) < lim; k, _ = c.Next() {
taskIDs = append(taskIDs, k) var id platform.ID
if err := id.Decode(k); err != nil {
return err
}
taskIDs = append(taskIDs, id)
} }
return nil return nil
} }
for k, _ := c.First(); k != nil && len(taskIDs) < lim; k, _ = c.Next() { for k, _ := c.First(); k != nil && len(taskIDs) < lim; k, _ = c.Next() {
taskIDs = append(taskIDs, k) var id platform.ID
if err := id.Decode(k); err != nil {
return err
}
taskIDs = append(taskIDs, id)
} }
return nil return nil
}) })
@ -259,34 +294,47 @@ func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams)
return ctx.Err() return ctx.Err()
default: default:
// TODO(docmerlin): change the setup to reduce the number of lookups to 1 or 2. // TODO(docmerlin): change the setup to reduce the number of lookups to 1 or 2.
paddedID := taskIDs[i] encID, err := taskIDs[i].Encode()
tasks[i].ID = unpadID(paddedID) if err != nil {
tasks[i].Script = string(b.Bucket(tasksPath).Get(paddedID)) return err
tasks[i].Name = string(b.Bucket(nameByTaskID).Get(paddedID)) }
tasks[i].ID = taskIDs[i]
tasks[i].Script = string(b.Bucket(tasksPath).Get(encID))
tasks[i].Name = string(b.Bucket(nameByTaskID).Get(encID))
} }
} }
if len(params.Org) > 0 { if params.Org.Valid() {
for i := range taskIDs { for i := range taskIDs {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
default: default:
paddedID := taskIDs[i] encID, err := taskIDs[i].Encode()
if err != nil {
return err
}
tasks[i].Org = params.Org tasks[i].Org = params.Org
tasks[i].User = b.Bucket(userByTaskID).Get(paddedID) if err := tasks[i].User.Decode(b.Bucket(userByTaskID).Get(encID)); err != nil {
return err
}
} }
} }
return nil return nil
} }
if len(params.User) > 0 { if params.User.Valid() {
for i := range taskIDs { for i := range taskIDs {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
default: default:
paddedID := taskIDs[i] encID, err := taskIDs[i].Encode()
if err != nil {
return err
}
tasks[i].User = params.User tasks[i].User = params.User
tasks[i].Org = b.Bucket(orgByTaskID).Get(paddedID) if err := tasks[i].Org.Decode(b.Bucket(orgByTaskID).Get(encID)); err != nil {
return err
}
} }
} }
return nil return nil
@ -296,9 +344,16 @@ func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams)
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
default: default:
paddedID := taskIDs[i] encID, err := taskIDs[i].Encode()
tasks[i].User = b.Bucket(userByTaskID).Get(paddedID) if err != nil {
tasks[i].Org = b.Bucket(orgByTaskID).Get(paddedID) return err
}
if err := tasks[i].User.Decode(b.Bucket(userByTaskID).Get(encID)); err != nil {
return err
}
if err := tasks[i].Org.Decode(b.Bucket(orgByTaskID).Get(encID)); err != nil {
return err
}
} }
} }
return nil return nil
@ -312,20 +367,27 @@ func (s *Store) ListTasks(ctx context.Context, params backend.TaskSearchParams)
func (s *Store) FindTaskByID(ctx context.Context, id platform.ID) (*backend.StoreTask, error) { func (s *Store) FindTaskByID(ctx context.Context, id platform.ID) (*backend.StoreTask, error) {
var stmBytes []byte var stmBytes []byte
var script []byte var script []byte
var userID []byte var userID platform.ID
var name []byte var name []byte
var org []byte var orgID platform.ID
paddedID := padID(id) encID, err := id.Encode()
err := s.db.View(func(tx *bolt.Tx) error { if err != nil {
return nil, err
}
err = s.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(s.bucket) b := tx.Bucket(s.bucket)
script = b.Bucket(tasksPath).Get(paddedID) script = b.Bucket(tasksPath).Get(encID)
if script == nil { if script == nil {
return ErrNotFound return ErrNotFound
} }
stmBytes = b.Bucket(taskMetaPath).Get(paddedID) stmBytes = b.Bucket(taskMetaPath).Get(encID)
userID = b.Bucket(userByTaskID).Get(paddedID) if err := userID.Decode(b.Bucket(userByTaskID).Get(encID)); err != nil {
name = b.Bucket(nameByTaskID).Get(paddedID) return err
org = b.Bucket(orgByTaskID).Get(paddedID) }
name = b.Bucket(nameByTaskID).Get(encID)
if orgID.Decode(b.Bucket(orgByTaskID).Get(encID)); err != nil {
return err
}
return nil return nil
}) })
if err == ErrNotFound { if err == ErrNotFound {
@ -342,8 +404,8 @@ func (s *Store) FindTaskByID(ctx context.Context, id platform.ID) (*backend.Stor
} }
return &backend.StoreTask{ return &backend.StoreTask{
ID: append([]byte(nil), id...), // copy of input id ID: id,
Org: org, Org: orgID,
User: userID, User: userID,
Name: string(name), Name: string(name),
Script: string(script), Script: string(script),
@ -376,38 +438,41 @@ func (s *Store) FindTaskMetaByID(ctx context.Context, id platform.ID) (*pb.Store
// DeleteTask deletes the task // DeleteTask deletes the task
func (s *Store) DeleteTask(ctx context.Context, id platform.ID) (deleted bool, err error) { func (s *Store) DeleteTask(ctx context.Context, id platform.ID) (deleted bool, err error) {
paddedID := padID(id) encID, err := id.Encode()
if err != nil {
return false, err
}
err = s.db.Batch(func(tx *bolt.Tx) error { err = s.db.Batch(func(tx *bolt.Tx) error {
b := tx.Bucket(s.bucket) b := tx.Bucket(s.bucket)
if check := b.Bucket(tasksPath).Get(paddedID); check == nil { if check := b.Bucket(tasksPath).Get(encID); check == nil {
return ErrNotFound return ErrNotFound
} }
if err := b.Bucket(taskMetaPath).Delete(paddedID); err != nil { if err := b.Bucket(taskMetaPath).Delete(encID); err != nil {
return err return err
} }
if err := b.Bucket(tasksPath).Delete(paddedID); err != nil { if err := b.Bucket(tasksPath).Delete(encID); err != nil {
return err return err
} }
user := b.Bucket(userByTaskID).Get(paddedID) user := b.Bucket(userByTaskID).Get(encID)
if len(user) > 0 { if len(user) > 0 {
if err := b.Bucket(usersPath).Bucket(user).Delete(paddedID); err != nil { if err := b.Bucket(usersPath).Bucket(user).Delete(encID); err != nil {
return err return err
} }
} }
if err := b.Bucket(userByTaskID).Delete(paddedID); err != nil { if err := b.Bucket(userByTaskID).Delete(encID); err != nil {
return err return err
} }
if err := b.Bucket(nameByTaskID).Delete(paddedID); err != nil { if err := b.Bucket(nameByTaskID).Delete(encID); err != nil {
return err return err
} }
org := b.Bucket(orgByTaskID).Get(paddedID) org := b.Bucket(orgByTaskID).Get(encID)
if len(org) > 0 { if len(org) > 0 {
if err := b.Bucket(orgsPath).Bucket(org).Delete(paddedID); err != nil { if err := b.Bucket(orgsPath).Bucket(org).Delete(encID); err != nil {
return err return err
} }
} }
return b.Bucket(orgByTaskID).Delete(paddedID) return b.Bucket(orgByTaskID).Delete(encID)
}) })
if err == ErrNotFound { if err == ErrNotFound {
return false, nil return false, nil
@ -420,12 +485,17 @@ func (s *Store) DeleteTask(ctx context.Context, id platform.ID) (deleted bool, e
// CreateRun adds `now` to the task's metaData if we have not exceeded 'max_concurrency'. // CreateRun adds `now` to the task's metaData if we have not exceeded 'max_concurrency'.
func (s *Store) CreateRun(ctx context.Context, taskID platform.ID, now int64) (backend.QueuedRun, error) { func (s *Store) CreateRun(ctx context.Context, taskID platform.ID, now int64) (backend.QueuedRun, error) {
queuedRun := backend.QueuedRun{TaskID: append([]byte(nil), taskID...), Now: now} queuedRun := backend.QueuedRun{TaskID: taskID, Now: now}
stm := pb.StoredTaskInternalMeta{} stm := pb.StoredTaskInternalMeta{}
paddedID := padID(taskID)
if err := s.db.Update(func(tx *bolt.Tx) error { if err := s.db.Update(func(tx *bolt.Tx) error {
encID, err := taskID.Encode()
if err != nil {
return err
}
b := tx.Bucket(s.bucket) b := tx.Bucket(s.bucket)
stmBytes := b.Bucket(taskMetaPath).Get(paddedID) stmBytes := b.Bucket(taskMetaPath).Get(encID)
if err := stm.Unmarshal(stmBytes); err != nil { if err := stm.Unmarshal(stmBytes); err != nil {
return err return err
} }
@ -449,11 +519,9 @@ func (s *Store) CreateRun(ctx context.Context, taskID platform.ID, now int64) (b
return err return err
} }
var runID [8]byte queuedRun.RunID = platform.ID(intID)
binary.BigEndian.PutUint64(runID[:], intID)
queuedRun.RunID = unpadID(runID[:])
return tx.Bucket(s.bucket).Bucket(taskMetaPath).Put(paddedID, stmBytes) return tx.Bucket(s.bucket).Bucket(taskMetaPath).Put(encID, stmBytes)
}); err != nil { }); err != nil {
return queuedRun, err return queuedRun, err
} }
@ -464,19 +532,20 @@ func (s *Store) CreateRun(ctx context.Context, taskID platform.ID, now int64) (b
// FinishRun removes runID from the list of running tasks and if its `now` is later then last completed update it. // FinishRun removes runID from the list of running tasks and if its `now` is later then last completed update it.
func (s *Store) FinishRun(ctx context.Context, taskID, runID platform.ID) error { func (s *Store) FinishRun(ctx context.Context, taskID, runID platform.ID) error {
stm := pb.StoredTaskInternalMeta{} stm := pb.StoredTaskInternalMeta{}
paddedID := padID(taskID) encID, err := taskID.Encode()
if err != nil {
intID := binary.BigEndian.Uint64(padID(runID)) return err
}
return s.db.Update(func(tx *bolt.Tx) error { return s.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(s.bucket) b := tx.Bucket(s.bucket)
stmBytes := b.Bucket(taskMetaPath).Get(paddedID) stmBytes := b.Bucket(taskMetaPath).Get(encID)
if err := stm.Unmarshal(stmBytes); err != nil { if err := stm.Unmarshal(stmBytes); err != nil {
return err return err
} }
found := false found := false
for i, runner := range stm.CurrentlyRunning { for i, runner := range stm.CurrentlyRunning {
if runner.RunID == intID { if platform.ID(runner.RunID) == runID {
found = true found = true
stm.CurrentlyRunning = append(stm.CurrentlyRunning[:i], stm.CurrentlyRunning[i+1:]...) stm.CurrentlyRunning = append(stm.CurrentlyRunning[:i], stm.CurrentlyRunning[i+1:]...)
if runner.NowTimestampUnix > stm.LastCompletedTimestampUnix { if runner.NowTimestampUnix > stm.LastCompletedTimestampUnix {
@ -494,7 +563,7 @@ func (s *Store) FinishRun(ctx context.Context, taskID, runID platform.ID) error
return err return err
} }
return tx.Bucket(s.bucket).Bucket(taskMetaPath).Put(paddedID, stmBytes) return tx.Bucket(s.bucket).Bucket(taskMetaPath).Put(encID, stmBytes)
}) })
} }