diff --git a/cmd/influxd/launcher/launcher.go b/cmd/influxd/launcher/launcher.go index b7fcb956c6..f3403d8bd3 100644 --- a/cmd/influxd/launcher/launcher.go +++ b/cmd/influxd/launcher/launcher.go @@ -940,11 +940,10 @@ func (m *Launcher) run(ctx context.Context, opts *InfluxdOpts) (err error) { ) } - notebookSvc, err := notebooks.NewService() - if err != nil { - m.log.Error("Failed to initialize notebook service", zap.Error(err)) - return err - } + notebookSvc := notebooks.NewService( + m.log.With(zap.String("service", "notebooks")), + m.sqlStore, + ) notebookServer := notebookTransport.NewNotebookHandler( m.log.With(zap.String("handler", "notebooks")), authorizer.NewNotebookService(notebookSvc), diff --git a/go.mod b/go.mod index d43159fa68..2f5ba7e379 100644 --- a/go.mod +++ b/go.mod @@ -52,11 +52,11 @@ require ( github.com/influxdata/pkg-config v0.2.7 github.com/influxdata/usage-client v0.0.0-20160829180054-6d3895376368 github.com/jessevdk/go-flags v1.4.0 + github.com/jmoiron/sqlx v1.3.4 github.com/jsternberg/zap-logfmt v1.2.0 github.com/jwilder/encoding v0.0.0-20170811194829-b4e1701a28ef github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect github.com/kevinburke/go-bindata v3.11.0+incompatible - github.com/lib/pq v1.2.0 // indirect github.com/mattn/go-isatty v0.0.12 github.com/mattn/go-sqlite3 v1.14.7 github.com/matttproud/golang_protobuf_extensions v1.0.1 diff --git a/go.sum b/go.sum index 80c78816b1..f8f6add0c2 100644 --- a/go.sum +++ b/go.sum @@ -354,6 +354,8 @@ github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGAR github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w= +github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -413,6 +415,7 @@ github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzp github.com/mattn/go-runewidth v0.0.7 h1:Ei8KR0497xHyKJPAv59M1dkC+rOZCMBJ+t3fZ+twI54= github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-tty v0.0.0-20180907095812-13ff1204f104 h1:d8RFOZ2IiFtFWBcKEHAFYJcPTf0wY5q0exFNJZVWa1U= diff --git a/kit/platform/id.go b/kit/platform/id.go index af6a8a4977..48d5c528a3 100644 --- a/kit/platform/id.go +++ b/kit/platform/id.go @@ -1,6 +1,7 @@ package platform import ( + "database/sql/driver" "encoding/binary" "encoding/hex" "strconv" @@ -143,3 +144,20 @@ func (i ID) MarshalText() ([]byte, error) { func (i *ID) UnmarshalText(b []byte) error { return i.Decode(b) } + +// Value implements the database/sql Valuer interface for adding IDs to a sql database. +func (i ID) Value() (driver.Value, error) { + return i.String(), nil +} + +// Scan implements the database/sql Scanner interface for retrieving IDs from a sql database. +func (i *ID) Scan(value interface{}) error { + switch v := value.(type) { + case int64: + return i.DecodeFromString(strconv.FormatInt(v, 10)) + case string: + return i.DecodeFromString(v) + default: + return ErrInvalidID + } +} diff --git a/notebook.go b/notebook.go index cca11f1570..fb301f7736 100644 --- a/notebook.go +++ b/notebook.go @@ -2,7 +2,10 @@ package influxdb import ( "context" + "database/sql/driver" + "encoding/json" "fmt" + "strings" "time" "github.com/influxdata/influxdb/v2/kit/platform" @@ -36,17 +39,38 @@ func fieldRequiredError(field string) error { // Notebook represents all visual and query data for a notebook. type Notebook struct { - OrgID platform.ID `json:"orgID"` - ID platform.ID `json:"id"` - Name string `json:"name"` - Spec NotebookSpec `json:"spec"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + OrgID platform.ID `json:"orgID" db:"org_id"` + ID platform.ID `json:"id" db:"id"` + Name string `json:"name" db:"name"` + Spec NotebookSpec `json:"spec" db:"spec"` + CreatedAt time.Time `json:"createdAt" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt" db:"updated_at"` } // NotebookSpec is an abitrary JSON object provided by the client. type NotebookSpec map[string]interface{} +// Value implements the database/sql Valuer interface for adding NotebookSpecs to the database. +func (s NotebookSpec) Value() (driver.Value, error) { + spec, err := json.Marshal(s) + if err != nil { + return nil, err + } + + return string(spec), nil +} + +// Scan implements the database/sql Scanner interface for retrieving NotebookSpecs from the database. +func (s *NotebookSpec) Scan(value interface{}) error { + var spec NotebookSpec + if err := json.NewDecoder(strings.NewReader(value.(string))).Decode(&spec); err != nil { + return err + } + + *s = spec + return nil +} + // NotebookService is the service contract for Notebooks. type NotebookService interface { GetNotebook(ctx context.Context, id platform.ID) (*Notebook, error) diff --git a/notebooks/fake_store.go b/notebooks/fake_store.go deleted file mode 100644 index 7ea023ccce..0000000000 --- a/notebooks/fake_store.go +++ /dev/null @@ -1,115 +0,0 @@ -// This file is a placeholder for an actual notebooks service implementation. -// For now it enables user experimentation with the UI in front of the notebooks -// backend server. - -package notebooks - -import ( - "context" - "time" - - "github.com/influxdata/influxdb/v2" - "github.com/influxdata/influxdb/v2/kit/platform" - "github.com/influxdata/influxdb/v2/snowflake" -) - -var _ influxdb.NotebookService = (*FakeStore)(nil) - -type FakeStore struct { - list map[string][]*influxdb.Notebook -} - -func NewService() (*FakeStore, error) { - return &FakeStore{ - list: make(map[string][]*influxdb.Notebook), - }, nil -} - -func (s *FakeStore) GetNotebook(ctx context.Context, id platform.ID) (*influxdb.Notebook, error) { - ns := []*influxdb.Notebook{} - - for _, nList := range s.list { - ns = append(ns, nList...) - } - - for _, n := range ns { - if n.ID == id { - return n, nil - } - } - - return nil, influxdb.ErrNotebookNotFound -} - -func (s *FakeStore) ListNotebooks(ctx context.Context, filter influxdb.NotebookListFilter) ([]*influxdb.Notebook, error) { - o := filter.OrgID - - ns, ok := s.list[o.String()] - if !ok { - return []*influxdb.Notebook{}, nil - } - - return ns, nil -} - -func (s *FakeStore) CreateNotebook(ctx context.Context, create *influxdb.NotebookReqBody) (*influxdb.Notebook, error) { - n := &influxdb.Notebook{ - OrgID: create.OrgID, - Name: create.Name, - Spec: create.Spec, - ID: snowflake.NewDefaultIDGenerator().ID(), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - idStr := create.OrgID.String() - c := s.list[idStr] - - ns := append(c, n) - s.list[idStr] = ns - - return n, nil -} - -func (s *FakeStore) DeleteNotebook(ctx context.Context, id platform.ID) error { - var foundOrg string - for org, nList := range s.list { - for _, b := range nList { - if b.ID == id { - foundOrg = org - } - } - } - - if foundOrg == "" { - return influxdb.ErrNotebookNotFound - } - - newNs := []*influxdb.Notebook{} - - for _, b := range s.list[foundOrg] { - if b.ID != id { - newNs = append(newNs, b) - } - } - - s.list[foundOrg] = newNs - return nil -} - -func (s *FakeStore) UpdateNotebook(ctx context.Context, id platform.ID, update *influxdb.NotebookReqBody) (*influxdb.Notebook, error) { - n, err := s.GetNotebook(ctx, id) - if err != nil { - return nil, err - } - - if update.Name != "" { - n.Name = update.Name - } - - if len(update.Spec) > 0 { - n.Spec = update.Spec - } - - return n, nil -} diff --git a/notebooks/service.go b/notebooks/service.go new file mode 100644 index 0000000000..fca9d58195 --- /dev/null +++ b/notebooks/service.go @@ -0,0 +1,152 @@ +package notebooks + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/kit/platform" + "github.com/influxdata/influxdb/v2/snowflake" + "github.com/influxdata/influxdb/v2/sqlite" + "go.uber.org/zap" +) + +var _ influxdb.NotebookService = (*Service)(nil) + +type Service struct { + store *sqlite.SqlStore + log *zap.Logger + idGenerator platform.IDGenerator +} + +func NewService(logger *zap.Logger, store *sqlite.SqlStore) *Service { + return &Service{ + store: store, + log: logger, + idGenerator: snowflake.NewIDGenerator(), + } +} + +func (s *Service) GetNotebook(ctx context.Context, id platform.ID) (*influxdb.Notebook, error) { + var n influxdb.Notebook + + query := ` + SELECT id, org_id, name, spec, created_at, updated_at + FROM notebooks WHERE id = $1` + + if err := s.store.DB.GetContext(ctx, &n, query, id); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, influxdb.ErrNotebookNotFound + } + + return nil, err + } + + return &n, nil +} + +// CreateNotebook creates a notebook. Note that this and all "write" operations on the database need to use the Mutex lock, +// since sqlite can only handle 1 concurrent write operation at a time. +func (s *Service) CreateNotebook(ctx context.Context, create *influxdb.NotebookReqBody) (*influxdb.Notebook, error) { + s.store.Mu.Lock() + defer s.store.Mu.Unlock() + + nowTime := time.Now().UTC() + n := influxdb.Notebook{ + ID: s.idGenerator.ID(), + OrgID: create.OrgID, + Name: create.Name, + Spec: create.Spec, + CreatedAt: nowTime, + UpdatedAt: nowTime, + } + + query := ` + INSERT INTO notebooks (id, org_id, name, spec, created_at, updated_at) + VALUES (:id, :org_id, :name, :spec, :created_at, :updated_at)` + + _, err := s.store.DB.NamedExecContext(ctx, query, &n) + if err != nil { + return nil, err + } + + // Ideally, the create query would use "RETURNING" in order to avoid making a separate query. + // Unfortunately this breaks the scanning of values into the result struct, so we have to make a separate + // SELECT request to return the result from the database. + return s.GetNotebook(ctx, n.ID) +} + +// UpdateNotebook updates a notebook. +func (s *Service) UpdateNotebook(ctx context.Context, id platform.ID, update *influxdb.NotebookReqBody) (*influxdb.Notebook, error) { + s.store.Mu.Lock() + defer s.store.Mu.Unlock() + + nowTime := time.Now().UTC() + n := influxdb.Notebook{ + ID: id, + OrgID: update.OrgID, + Name: update.Name, + Spec: update.Spec, + UpdatedAt: nowTime, + } + + query := ` + UPDATE notebooks SET org_id = :org_id, name = :name, spec = :spec, updated_at = :updated_at + WHERE id = :id` + + _, err := s.store.DB.NamedExecContext(ctx, query, &n) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, influxdb.ErrNotebookNotFound + } + + return nil, err + } + + return s.GetNotebook(ctx, n.ID) +} + +// DeleteNotebook deletes a notebook. +func (s *Service) DeleteNotebook(ctx context.Context, id platform.ID) error { + s.store.Mu.Lock() + defer s.store.Mu.Unlock() + + query := ` + DELETE FROM notebooks + WHERE id = $1` + + res, err := s.store.DB.ExecContext(ctx, query, id.String()) + if err != nil { + return err + } + + r, err := res.RowsAffected() + if err != nil { + return err + } + + if r == 0 { + return influxdb.ErrNotebookNotFound + } + + return nil +} + +// ListNotebooks lists notebooks matching the provided filter. Currently, only org_id is used in the filter. +// Future uses may support pagination via this filter as well. +func (s *Service) ListNotebooks(ctx context.Context, filter influxdb.NotebookListFilter) ([]*influxdb.Notebook, error) { + var ns []*influxdb.Notebook + + query := ` + SELECT id, org_id, name, spec, created_at, updated_at + FROM notebooks + WHERE org_id = $1` + + if err := s.store.DB.SelectContext(ctx, &ns, query, filter.OrgID); err != nil { + return nil, err + } + + return ns, nil +} diff --git a/notebooks/service_test.go b/notebooks/service_test.go new file mode 100644 index 0000000000..51d2b29920 --- /dev/null +++ b/notebooks/service_test.go @@ -0,0 +1,209 @@ +package notebooks + +import ( + "context" + "fmt" + "testing" + + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/snowflake" + "github.com/influxdata/influxdb/v2/sqlite" + "github.com/influxdata/influxdb/v2/sqlite/migrations" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +var ( + idGen = snowflake.NewIDGenerator() +) + +func TestCreateAndGetNotebook(t *testing.T) { + t.Parallel() + + svc, clean := newTestService(t) + defer clean(t) + ctx := context.Background() + + // getting an invalid id should return an error + got, err := svc.GetNotebook(ctx, idGen.ID()) + require.Nil(t, got) + require.ErrorIs(t, influxdb.ErrNotebookNotFound, err) + + testCreate := &influxdb.NotebookReqBody{ + OrgID: idGen.ID(), + Name: "some name", + Spec: map[string]interface{}{"hello": "goodbye"}, + } + + // create a notebook and assert the results + gotCreate, err := svc.CreateNotebook(ctx, testCreate) + require.NoError(t, err) + gotCreateBody := &influxdb.NotebookReqBody{ + OrgID: gotCreate.OrgID, + Name: gotCreate.Name, + Spec: gotCreate.Spec, + } + require.Equal(t, testCreate, gotCreateBody) + + // get the notebook with the ID that was created and assert the results + gotGet, err := svc.GetNotebook(ctx, gotCreate.ID) + require.NoError(t, err) + gotGetBody := &influxdb.NotebookReqBody{ + OrgID: gotGet.OrgID, + Name: gotGet.Name, + Spec: gotGet.Spec, + } + require.Equal(t, testCreate, gotGetBody) +} + +func TestUpdate(t *testing.T) { + t.Parallel() + + svc, clean := newTestService(t) + defer clean(t) + ctx := context.Background() + + testCreate := &influxdb.NotebookReqBody{ + OrgID: idGen.ID(), + Name: "some name", + Spec: map[string]interface{}{"hello": "goodbye"}, + } + + testUpdate := &influxdb.NotebookReqBody{ + OrgID: testCreate.OrgID, + Name: "a new name", + Spec: map[string]interface{}{"aloha": "aloha"}, + } + + // attempting to update a non-existant notebook should return an error + got, err := svc.UpdateNotebook(ctx, idGen.ID(), testUpdate) + require.Nil(t, got) + require.ErrorIs(t, influxdb.ErrNotebookNotFound, err) + + // create the notebook so updating it can be tested + gotCreate, err := svc.CreateNotebook(ctx, testCreate) + require.NoError(t, err) + gotCreateBody := &influxdb.NotebookReqBody{ + OrgID: gotCreate.OrgID, + Name: gotCreate.Name, + Spec: gotCreate.Spec, + } + require.Equal(t, testCreate, gotCreateBody) + + // try to update the notebook and assert the results + gotUpdate, err := svc.UpdateNotebook(ctx, gotCreate.ID, testUpdate) + require.NoError(t, err) + gotUpdateBody := &influxdb.NotebookReqBody{ + OrgID: gotUpdate.OrgID, + Name: gotUpdate.Name, + Spec: gotUpdate.Spec, + } + + require.Equal(t, testUpdate, gotUpdateBody) + require.Equal(t, gotCreate.ID, gotUpdate.ID) + require.Equal(t, gotCreate.CreatedAt, gotUpdate.CreatedAt) + require.NotEqual(t, gotUpdate.CreatedAt, gotUpdate.UpdatedAt) +} + +func TestDelete(t *testing.T) { + t.Parallel() + + svc, clean := newTestService(t) + defer clean(t) + ctx := context.Background() + + // attempting to delete a non-existant notebook should return an error + err := svc.DeleteNotebook(ctx, idGen.ID()) + fmt.Println(err) + require.ErrorIs(t, influxdb.ErrNotebookNotFound, err) + + testCreate := &influxdb.NotebookReqBody{ + OrgID: idGen.ID(), + Name: "some name", + Spec: map[string]interface{}{"hello": "goodbye"}, + } + + // create the notebook that we are going to try to delete + gotCreate, err := svc.CreateNotebook(ctx, testCreate) + require.NoError(t, err) + gotCreateBody := &influxdb.NotebookReqBody{ + OrgID: gotCreate.OrgID, + Name: gotCreate.Name, + Spec: gotCreate.Spec, + } + require.Equal(t, testCreate, gotCreateBody) + + // should be able to successfully delete the notebook now + err = svc.DeleteNotebook(ctx, gotCreate.ID) + require.NoError(t, err) + + // ensure the notebook no longer exists + _, err = svc.GetNotebook(ctx, gotCreate.ID) + require.ErrorIs(t, influxdb.ErrNotebookNotFound, err) +} + +func TestList(t *testing.T) { + t.Parallel() + + svc, clean := newTestService(t) + defer clean(t) + ctx := context.Background() + + orgID := idGen.ID() + + // selecting with no matches for org_id should return an empty list and no error + got, err := svc.ListNotebooks(ctx, influxdb.NotebookListFilter{OrgID: orgID}) + require.NoError(t, err) + require.Equal(t, 0, len(got)) + + // create some notebooks to test the list operation with + creates := []*influxdb.NotebookReqBody{ + { + OrgID: orgID, + Name: "some name", + Spec: map[string]interface{}{"hello": "goodbye"}, + }, + { + OrgID: orgID, + Name: "another name", + Spec: map[string]interface{}{"aloha": "aloha"}, + }, + { + OrgID: orgID, + Name: "some name", + Spec: map[string]interface{}{"hola": "adios"}, + }, + } + + for _, c := range creates { + _, err := svc.CreateNotebook(ctx, c) + require.NoError(t, err) + } + + // there should now be notebooks returned from ListNotebooks + got, err = svc.ListNotebooks(ctx, influxdb.NotebookListFilter{OrgID: orgID}) + require.NoError(t, err) + require.Equal(t, len(creates), len(got)) + + // make sure the elements from the returned list were from the list of notebooks to create + for _, n := range got { + require.Contains(t, creates, &influxdb.NotebookReqBody{ + OrgID: n.OrgID, + Name: n.Name, + Spec: n.Spec, + }) + } +} + +func newTestService(t *testing.T) (*Service, func(t *testing.T)) { + store, clean := sqlite.NewTestStore(t) + ctx := context.Background() + + sqliteMigrator := sqlite.NewMigrator(store, zap.NewNop()) + err := sqliteMigrator.Up(ctx, &migrations.All{}) + require.NoError(t, err) + + svc := NewService(zap.NewNop(), store) + + return svc, clean +} diff --git a/notebooks/transport/http.go b/notebooks/transport/http.go index e5f1db87d8..d8677599af 100644 --- a/notebooks/transport/http.go +++ b/notebooks/transport/http.go @@ -14,7 +14,8 @@ import ( ) const ( - prefixNotebooks = "/api/v2private/notebooks" + prefixNotebooks = "/api/v2private/notebooks" + allNotebooksJSONKey = "flows" ) var ( @@ -107,7 +108,11 @@ func (h *NotebookHandler) handleGetNotebooks(w http.ResponseWriter, r *http.Requ return } - h.api.Respond(w, r, http.StatusOK, l) + p := map[string][]*influxdb.Notebook{ + allNotebooksJSONKey: l, + } + + h.api.Respond(w, r, http.StatusOK, p) } // create a single notebook. diff --git a/notebooks/transport/http_test.go b/notebooks/transport/http_test.go index 3a297eb784..26217ebcbc 100644 --- a/notebooks/transport/http_test.go +++ b/notebooks/transport/http_test.go @@ -58,10 +58,10 @@ func TestNotebookHandler(t *testing.T) { res := doTestRequest(t, req, http.StatusOK, true) - got := []*influxdb.Notebook{} + got := map[string][]*influxdb.Notebook{} err := json.NewDecoder(res.Body).Decode(&got) require.NoError(t, err) - require.Equal(t, got, []*influxdb.Notebook{testNotebook}) + require.Equal(t, got[allNotebooksJSONKey], []*influxdb.Notebook{testNotebook}) }) t.Run("create notebook happy path", func(t *testing.T) { diff --git a/sqlite/migrator_test.go b/sqlite/migrator_test.go index 2a28b987c4..305a08f900 100644 --- a/sqlite/migrator_test.go +++ b/sqlite/migrator_test.go @@ -13,7 +13,7 @@ import ( func TestUp(t *testing.T) { t.Parallel() - store, clean := newTestStore(t) + store, clean := NewTestStore(t) defer clean(t) ctx := context.Background() diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index b6d9951a96..b43f893125 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -2,12 +2,12 @@ package sqlite import ( "context" - "database/sql" "fmt" "strconv" "sync" // sqlite3 driver + "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" "go.uber.org/zap" @@ -21,13 +21,13 @@ const ( // SqlStore is a wrapper around the db and provides basic functionality for maintaining the db // including flushing the data from the db during end-to-end testing. type SqlStore struct { - mu sync.Mutex - db *sql.DB + Mu sync.Mutex + DB *sqlx.DB log *zap.Logger } func NewSqlStore(path string, log *zap.Logger) (*SqlStore, error) { - db, err := sql.Open("sqlite3", path) + db, err := sqlx.Open("sqlite3", path) if err != nil { return nil, err } @@ -42,14 +42,14 @@ func NewSqlStore(path string, log *zap.Logger) (*SqlStore, error) { } return &SqlStore{ - db: db, + DB: db, log: log, }, nil } // Close the connection to the sqlite database func (s *SqlStore) Close() error { - err := s.db.Close() + err := s.DB.Close() if err != nil { return err } @@ -77,10 +77,10 @@ func (s *SqlStore) Flush(ctx context.Context) { func (s *SqlStore) execTrans(ctx context.Context, stmt string) error { // use a lock to prevent two potential simultaneous write operations to the database, // which would throw an error - s.mu.Lock() - defer s.mu.Unlock() + s.Mu.Lock() + defer s.Mu.Unlock() - tx, err := s.db.BeginTx(ctx, nil) + tx, err := s.DB.BeginTx(ctx, nil) if err != nil { return err } @@ -124,7 +124,7 @@ func (s *SqlStore) tableNames() ([]string, error) { func (s *SqlStore) queryToStrings(stmt string) ([]string, error) { var output []string - rows, err := s.db.Query(stmt) + rows, err := s.DB.Query(stmt) if err != nil { return nil, err } diff --git a/sqlite/sqlite_helpers.go b/sqlite/sqlite_helpers.go new file mode 100644 index 0000000000..f6d6e8d9de --- /dev/null +++ b/sqlite/sqlite_helpers.go @@ -0,0 +1,29 @@ +package sqlite + +import ( + "io/ioutil" + "os" + "testing" + + "go.uber.org/zap" +) + +func NewTestStore(t *testing.T) (*SqlStore, func(t *testing.T)) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("unable to create temporary test directory %v", err) + } + + cleanUpFn := func(t *testing.T) { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("unable to delete temporary test directory %s: %v", tempDir, err) + } + } + + s, err := NewSqlStore(tempDir+"/"+DefaultFilename, zap.NewNop()) + if err != nil { + t.Fatal("unable to open testing database") + } + + return s, cleanUpFn +} diff --git a/sqlite/sqlite_test.go b/sqlite/sqlite_test.go index 50deb6f481..58541d6b1c 100644 --- a/sqlite/sqlite_test.go +++ b/sqlite/sqlite_test.go @@ -2,19 +2,16 @@ package sqlite import ( "context" - "io/ioutil" - "os" "testing" "github.com/stretchr/testify/require" - "go.uber.org/zap" ) func TestFlush(t *testing.T) { t.Parallel() ctx := context.Background() - store, clean := newTestStore(t) + store, clean := NewTestStore(t) defer clean(t) err := store.execTrans(ctx, `CREATE TABLE test_table_1 (id TEXT NOT NULL PRIMARY KEY)`) @@ -37,7 +34,7 @@ func TestFlush(t *testing.T) { func TestUserVersion(t *testing.T) { t.Parallel() - store, clean := newTestStore(t) + store, clean := NewTestStore(t) defer clean(t) ctx := context.Background() @@ -52,7 +49,7 @@ func TestUserVersion(t *testing.T) { func TestTableNames(t *testing.T) { t.Parallel() - store, clean := newTestStore(t) + store, clean := NewTestStore(t) defer clean(t) ctx := context.Background() @@ -65,23 +62,3 @@ func TestTableNames(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"test_table_1", "test_table_3", "test_table_2"}, got) } - -func newTestStore(t *testing.T) (*SqlStore, func(t *testing.T)) { - tempDir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatalf("unable to create temporary test directory %v", err) - } - - cleanUpFn := func(t *testing.T) { - if err := os.RemoveAll(tempDir); err != nil { - t.Fatalf("unable to delete temporary test directory %s: %v", tempDir, err) - } - } - - s, err := NewSqlStore(tempDir+"/"+DefaultFilename, zap.NewNop()) - if err != nil { - t.Fatal("unable to open testing database") - } - - return s, cleanUpFn -}