diff --git a/cmd/influxd/launcher/launcher.go b/cmd/influxd/launcher/launcher.go index faf031c036..2b9497e655 100644 --- a/cmd/influxd/launcher/launcher.go +++ b/cmd/influxd/launcher/launcher.go @@ -861,7 +861,7 @@ func (m *Launcher) run(ctx context.Context, opts *InfluxdOpts) (err error) { ), ) - replicationSvc := replications.NewService() + replicationSvc := replications.NewService(m.sqlStore, ts) replicationServer := replicationTransport.NewReplicationHandler( m.log.With(zap.String("handler", "replications")), replications.NewLoggingService( diff --git a/remotes/service.go b/remotes/service.go index 0580498f03..6804c65d1a 100644 --- a/remotes/service.go +++ b/remotes/service.go @@ -92,9 +92,6 @@ func (s service) CreateRemoteConnection(ctx context.Context, request influxdb.Cr var rc influxdb.RemoteConnection if err := s.store.DB.GetContext(ctx, &rc, query, args...); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, errRemoteNotFound - } return nil, err } return &rc, nil diff --git a/remotes/service_test.go b/remotes/service_test.go index 29ff46bc26..a3f6c84a9e 100644 --- a/remotes/service_test.go +++ b/remotes/service_test.go @@ -23,7 +23,7 @@ var ( desc = "testing testing" connection = influxdb.RemoteConnection{ ID: initID, - OrgID: platform.ID(10), //createReq.OrgID, + OrgID: platform.ID(10), Name: "test", Description: &desc, RemoteURL: "https://influxdb.cloud", @@ -157,6 +157,23 @@ func TestUpdateAndGetConnection(t *testing.T) { require.NoError(t, svc.ValidateRemoteConnection(ctx, initID)) } +func TestUpdateNoop(t *testing.T) { + t.Parallel() + + svc, _, clean := newTestService(t) + defer clean(t) + + // Create a connection. + created, err := svc.CreateRemoteConnection(ctx, createReq) + require.NoError(t, err) + require.Equal(t, connection, *created) + + // Send a no-op update, assert nothing changed. + updated, err := svc.UpdateRemoteConnection(ctx, initID, influxdb.UpdateRemoteConnectionRequest{}) + require.NoError(t, err) + require.Equal(t, connection, *updated) +} + func TestValidateUpdatedConnectionWithoutPersisting(t *testing.T) { t.Parallel() @@ -315,7 +332,7 @@ func newTestService(t *testing.T) (*service, *remotesMock.MockRemoteConnectionVa mockValidator := remotesMock.NewMockRemoteConnectionValidator(gomock.NewController(t)) svc := service{ store: store, - idGenerator: mock.NewIncrementingIDGenerator(platform.ID(1)), + idGenerator: mock.NewIncrementingIDGenerator(initID), validator: mockValidator, } diff --git a/replication.go b/replication.go index fbcf7655f2..8108f233ec 100644 --- a/replication.go +++ b/replication.go @@ -49,17 +49,17 @@ type ReplicationService interface { // Replication contains all info about a replication that should be returned to users. type Replication struct { - ID platform.ID `json:"id"` - OrgID platform.ID `json:"orgID"` - Name string `json:"name"` - Description *string `json:"description,omitempty"` - RemoteID platform.ID `json:"remoteID"` - LocalBucketID platform.ID `json:"localBucketID"` - RemoteBucketID platform.ID `json:"remoteBucketID"` - MaxQueueSizeBytes int64 `json:"maxQueueSizeBytes"` - CurrentQueueSizeBytes int64 `json:"currentQueueSizeBytes"` - LatestResponseCode *int32 `json:"latestResponseCode,omitempty"` - LatestErrorMessage *string `json:"latestErrorMessage,omitempty"` + ID platform.ID `json:"id" db:"id"` + OrgID platform.ID `json:"orgID" db:"org_id"` + Name string `json:"name" db:"name"` + Description *string `json:"description,omitempty" db:"description"` + RemoteID platform.ID `json:"remoteID" db:"remote_id"` + LocalBucketID platform.ID `json:"localBucketID" db:"local_bucket_id"` + RemoteBucketID platform.ID `json:"remoteBucketID" db:"remote_bucket_id"` + MaxQueueSizeBytes int64 `json:"maxQueueSizeBytes" db:"max_queue_size_bytes"` + CurrentQueueSizeBytes int64 `json:"currentQueueSizeBytes" db:"current_queue_size_bytes"` + LatestResponseCode *int32 `json:"latestResponseCode,omitempty" db:"latest_response_code"` + LatestErrorMessage *string `json:"latestErrorMessage,omitempty" db:"latest_error_message"` } // ReplicationListFilter is a selection filter for listing replications. diff --git a/replications/internal/http_config.go b/replications/internal/http_config.go new file mode 100644 index 0000000000..92e5d86fb2 --- /dev/null +++ b/replications/internal/http_config.go @@ -0,0 +1,13 @@ +package internal + +import "github.com/influxdata/influxdb/v2/kit/platform" + +// ReplicationHTTPConfig contains all info needed by a client to make HTTP requests against the +// remote bucket targeted by a replication. +type ReplicationHTTPConfig struct { + RemoteURL string `db:"remote_url"` + RemoteToken string `db:"remote_api_token"` + RemoteOrgID platform.ID `db:"remote_org_id"` + AllowInsecureTLS bool `db:"allow_insecure_tls"` + RemoteBucketID platform.ID `db:"remote_bucket_id"` +} diff --git a/replications/internal/validator.go b/replications/internal/validator.go new file mode 100644 index 0000000000..e7c2a551f1 --- /dev/null +++ b/replications/internal/validator.go @@ -0,0 +1,17 @@ +package internal + +import ( + "context" + + ierrors "github.com/influxdata/influxdb/v2/kit/platform/errors" +) + +func NewValidator() *stubValidator { + return &stubValidator{} +} + +type stubValidator struct{} + +func (s stubValidator) ValidateReplication(ctx context.Context, config *ReplicationHTTPConfig) error { + return &ierrors.Error{Code: ierrors.ENotImplemented, Msg: "replication validation not implemented"} +} diff --git a/replications/mock/bucket_service.go b/replications/mock/bucket_service.go new file mode 100644 index 0000000000..e87ffdab81 --- /dev/null +++ b/replications/mock/bucket_service.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/influxdata/influxdb/v2/replications (interfaces: BucketService) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + influxdb "github.com/influxdata/influxdb/v2" + platform "github.com/influxdata/influxdb/v2/kit/platform" +) + +// MockBucketService is a mock of BucketService interface. +type MockBucketService struct { + ctrl *gomock.Controller + recorder *MockBucketServiceMockRecorder +} + +// MockBucketServiceMockRecorder is the mock recorder for MockBucketService. +type MockBucketServiceMockRecorder struct { + mock *MockBucketService +} + +// NewMockBucketService creates a new mock instance. +func NewMockBucketService(ctrl *gomock.Controller) *MockBucketService { + mock := &MockBucketService{ctrl: ctrl} + mock.recorder = &MockBucketServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBucketService) EXPECT() *MockBucketServiceMockRecorder { + return m.recorder +} + +// FindBucketByID mocks base method. +func (m *MockBucketService) FindBucketByID(arg0 context.Context, arg1 platform.ID) (*influxdb.Bucket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindBucketByID", arg0, arg1) + ret0, _ := ret[0].(*influxdb.Bucket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindBucketByID indicates an expected call of FindBucketByID. +func (mr *MockBucketServiceMockRecorder) FindBucketByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindBucketByID", reflect.TypeOf((*MockBucketService)(nil).FindBucketByID), arg0, arg1) +} + +// RLock mocks base method. +func (m *MockBucketService) RLock() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RLock") +} + +// RLock indicates an expected call of RLock. +func (mr *MockBucketServiceMockRecorder) RLock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RLock", reflect.TypeOf((*MockBucketService)(nil).RLock)) +} + +// RUnlock mocks base method. +func (m *MockBucketService) RUnlock() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RUnlock") +} + +// RUnlock indicates an expected call of RUnlock. +func (mr *MockBucketServiceMockRecorder) RUnlock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockBucketService)(nil).RUnlock)) +} diff --git a/replications/mock/validator.go b/replications/mock/validator.go new file mode 100644 index 0000000000..914bbd2c33 --- /dev/null +++ b/replications/mock/validator.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/influxdata/influxdb/v2/replications (interfaces: ReplicationValidator) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + internal "github.com/influxdata/influxdb/v2/replications/internal" +) + +// MockReplicationValidator is a mock of ReplicationValidator interface. +type MockReplicationValidator struct { + ctrl *gomock.Controller + recorder *MockReplicationValidatorMockRecorder +} + +// MockReplicationValidatorMockRecorder is the mock recorder for MockReplicationValidator. +type MockReplicationValidatorMockRecorder struct { + mock *MockReplicationValidator +} + +// NewMockReplicationValidator creates a new mock instance. +func NewMockReplicationValidator(ctrl *gomock.Controller) *MockReplicationValidator { + mock := &MockReplicationValidator{ctrl: ctrl} + mock.recorder = &MockReplicationValidatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReplicationValidator) EXPECT() *MockReplicationValidatorMockRecorder { + return m.recorder +} + +// ValidateReplication mocks base method. +func (m *MockReplicationValidator) ValidateReplication(arg0 context.Context, arg1 *internal.ReplicationHTTPConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateReplication", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateReplication indicates an expected call of ValidateReplication. +func (mr *MockReplicationValidatorMockRecorder) ValidateReplication(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateReplication", reflect.TypeOf((*MockReplicationValidator)(nil).ValidateReplication), arg0, arg1) +} diff --git a/replications/service.go b/replications/service.go index 282bd7e967..770e070a5f 100644 --- a/replications/service.go +++ b/replications/service.go @@ -2,53 +2,318 @@ package replications import ( "context" + "database/sql" + "errors" + "fmt" + sq "github.com/Masterminds/squirrel" "github.com/influxdata/influxdb/v2" "github.com/influxdata/influxdb/v2/kit/platform" ierrors "github.com/influxdata/influxdb/v2/kit/platform/errors" + "github.com/influxdata/influxdb/v2/replications/internal" + "github.com/influxdata/influxdb/v2/snowflake" + "github.com/influxdata/influxdb/v2/sqlite" + "github.com/mattn/go-sqlite3" ) -var errNotImplemented = &ierrors.Error{ - Code: ierrors.ENotImplemented, - Msg: "replication APIs not yet implemented", +var errReplicationNotFound = &ierrors.Error{ + Code: ierrors.ENotFound, + Msg: "replication not found", } -func NewService() *service { - return &service{} +func errRemoteNotFound(id platform.ID, cause error) error { + return &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: fmt.Sprintf("remote %q not found", id), + Err: cause, + } } -type service struct{} +func errLocalBucketNotFound(id platform.ID, cause error) error { + return &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: fmt.Sprintf("local bucket %q not found", id), + Err: cause, + } +} + +func NewService(store *sqlite.SqlStore, bktSvc BucketService) *service { + return &service{ + store: store, + idGenerator: snowflake.NewIDGenerator(), + bucketService: bktSvc, + validator: internal.NewValidator(), + } +} + +type ReplicationValidator interface { + ValidateReplication(context.Context, *internal.ReplicationHTTPConfig) error +} + +type BucketService interface { + RLock() + RUnlock() + FindBucketByID(ctx context.Context, id platform.ID) (*influxdb.Bucket, error) +} + +type service struct { + store *sqlite.SqlStore + idGenerator platform.IDGenerator + bucketService BucketService + validator ReplicationValidator +} var _ influxdb.ReplicationService = (*service)(nil) func (s service) ListReplications(ctx context.Context, filter influxdb.ReplicationListFilter) (*influxdb.Replications, error) { - return nil, errNotImplemented + q := sq.Select( + "id", "org_id", "name", "description", + "remote_id", "local_bucket_id", "remote_bucket_id", + "max_queue_size_bytes", "current_queue_size_bytes", + "latest_response_code", "latest_error_message"). + From("replications"). + Where(sq.Eq{"org_id": filter.OrgID}) + + if filter.Name != nil { + q = q.Where(sq.Eq{"name": *filter.Name}) + } + if filter.RemoteID != nil { + q = q.Where(sq.Eq{"remote_id": *filter.RemoteID}) + } + if filter.LocalBucketID != nil { + q = q.Where(sq.Eq{"local_bucket_id": *filter.LocalBucketID}) + } + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var rs influxdb.Replications + if err := s.store.DB.SelectContext(ctx, &rs.Replications, query, args...); err != nil { + return nil, err + } + return &rs, nil } func (s service) CreateReplication(ctx context.Context, request influxdb.CreateReplicationRequest) (*influxdb.Replication, error) { - return nil, errNotImplemented + s.bucketService.RLock() + defer s.bucketService.RUnlock() + + s.store.Mu.Lock() + defer s.store.Mu.Unlock() + + if _, err := s.bucketService.FindBucketByID(ctx, request.LocalBucketID); err != nil { + return nil, errLocalBucketNotFound(request.LocalBucketID, err) + } + + q := sq.Insert("replications"). + SetMap(sq.Eq{ + "id": s.idGenerator.ID(), + "org_id": request.OrgID, + "name": request.Name, + "description": request.Description, + "remote_id": request.RemoteID, + "local_bucket_id": request.LocalBucketID, + "remote_bucket_id": request.RemoteBucketID, + "max_queue_size_bytes": request.MaxQueueSizeBytes, + "current_queue_size_bytes": 0, + "created_at": "datetime('now')", + "updated_at": "datetime('now')", + }). + Suffix("RETURNING id, org_id, name, description, remote_id, local_bucket_id, remote_bucket_id, max_queue_size_bytes, current_queue_size_bytes") + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var r influxdb.Replication + if err := s.store.DB.GetContext(ctx, &r, query, args...); err != nil { + if sqlErr, ok := err.(sqlite3.Error); ok && sqlErr.ExtendedCode == sqlite3.ErrConstraintForeignKey { + return nil, errRemoteNotFound(request.RemoteID, err) + } + return nil, err + } + return &r, nil } func (s service) ValidateNewReplication(ctx context.Context, request influxdb.CreateReplicationRequest) error { - return errNotImplemented + if _, err := s.bucketService.FindBucketByID(ctx, request.LocalBucketID); err != nil { + return errLocalBucketNotFound(request.LocalBucketID, err) + } + + config := internal.ReplicationHTTPConfig{RemoteBucketID: request.RemoteBucketID} + if err := s.populateRemoteHTTPConfig(ctx, request.RemoteID, &config); err != nil { + return err + } + + if err := s.validator.ValidateReplication(ctx, &config); err != nil { + return &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: "replication parameters fail validation", + Err: err, + } + } + return nil } func (s service) GetReplication(ctx context.Context, id platform.ID) (*influxdb.Replication, error) { - return nil, errNotImplemented + q := sq.Select( + "id", "org_id", "name", "description", + "remote_id", "local_bucket_id", "remote_bucket_id", + "max_queue_size_bytes", "current_queue_size_bytes", + "latest_response_code", "latest_error_message"). + From("replications"). + Where(sq.Eq{"id": id}) + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var r influxdb.Replication + if err := s.store.DB.GetContext(ctx, &r, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errReplicationNotFound + } + return nil, err + } + return &r, nil } func (s service) UpdateReplication(ctx context.Context, id platform.ID, request influxdb.UpdateReplicationRequest) (*influxdb.Replication, error) { - return nil, errNotImplemented + s.store.Mu.Lock() + defer s.store.Mu.Unlock() + + updates := sq.Eq{"updated_at": sq.Expr("datetime('now')")} + if request.Name != nil { + updates["name"] = *request.Name + } + if request.Description != nil { + updates["description"] = *request.Description + } + if request.RemoteID != nil { + updates["remote_id"] = *request.RemoteID + } + if request.RemoteBucketID != nil { + updates["remote_bucket_id"] = *request.RemoteBucketID + } + if request.MaxQueueSizeBytes != nil { + updates["max_queue_size_bytes"] = *request.MaxQueueSizeBytes + } + + q := sq.Update("replications").SetMap(updates).Where(sq.Eq{"id": id}). + Suffix("RETURNING id, org_id, name, description, remote_id, local_bucket_id, remote_bucket_id, max_queue_size_bytes, current_queue_size_bytes") + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var r influxdb.Replication + if err := s.store.DB.GetContext(ctx, &r, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errReplicationNotFound + } + if sqlErr, ok := err.(sqlite3.Error); ok && request.RemoteID != nil && sqlErr.ExtendedCode == sqlite3.ErrConstraintForeignKey { + return nil, errRemoteNotFound(*request.RemoteID, err) + } + return nil, err + } + return &r, nil } func (s service) ValidateUpdatedReplication(ctx context.Context, id platform.ID, request influxdb.UpdateReplicationRequest) error { - return errNotImplemented + baseConfig, err := s.getFullHTTPConfig(ctx, id) + if err != nil { + return err + } + + if request.RemoteID != nil { + if err := s.populateRemoteHTTPConfig(ctx, *request.RemoteID, baseConfig); err != nil { + return err + } + } + + if err := s.validator.ValidateReplication(ctx, baseConfig); err != nil { + return &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: "validation fails after applying update", + Err: err, + } + } + return nil } func (s service) DeleteReplication(ctx context.Context, id platform.ID) error { - return errNotImplemented + s.store.Mu.Lock() + defer s.store.Mu.Unlock() + + q := sq.Delete("replications").Where(sq.Eq{"id": id}).Suffix("RETURNING id") + query, args, err := q.ToSql() + if err != nil { + return err + } + + var d platform.ID + if err := s.store.DB.GetContext(ctx, &d, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errReplicationNotFound + } + return err + } + return nil } func (s service) ValidateReplication(ctx context.Context, id platform.ID) error { - return errNotImplemented + config, err := s.getFullHTTPConfig(ctx, id) + if err != nil { + return err + } + if err := s.validator.ValidateReplication(ctx, config); err != nil { + return &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: "remote failed validation", + Err: err, + } + } + return nil +} + +func (s service) getFullHTTPConfig(ctx context.Context, id platform.ID) (*internal.ReplicationHTTPConfig, error) { + q := sq.Select("c.remote_url", "c.remote_api_token", "c.remote_org_id", "c.allow_insecure_tls", "r.remote_bucket_id"). + From("replications r").InnerJoin("remotes c ON r.remote_id = c.id AND r.id = ?", id) + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var rc internal.ReplicationHTTPConfig + if err := s.store.DB.GetContext(ctx, &rc, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errReplicationNotFound + } + return nil, err + } + return &rc, nil +} + +func (s service) populateRemoteHTTPConfig(ctx context.Context, id platform.ID, target *internal.ReplicationHTTPConfig) error { + q := sq.Select("remote_url", "remote_api_token", "remote_org_id", "allow_insecure_tls"). + From("remotes").Where(sq.Eq{"id": id}) + query, args, err := q.ToSql() + if err != nil { + return err + } + + if err := s.store.DB.GetContext(ctx, target, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errRemoteNotFound(id, nil) + } + return err + } + + return nil } diff --git a/replications/service_test.go b/replications/service_test.go new file mode 100644 index 0000000000..d01a7edead --- /dev/null +++ b/replications/service_test.go @@ -0,0 +1,566 @@ +package replications + +import ( + "context" + "errors" + "fmt" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/golang/mock/gomock" + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/kit/platform" + "github.com/influxdata/influxdb/v2/mock" + "github.com/influxdata/influxdb/v2/replications/internal" + replicationsMock "github.com/influxdata/influxdb/v2/replications/mock" + "github.com/influxdata/influxdb/v2/sqlite" + "github.com/influxdata/influxdb/v2/sqlite/migrations" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +var ( + ctx = context.Background() + initID = platform.ID(1) + desc = "testing testing" + replication = influxdb.Replication{ + ID: initID, + OrgID: platform.ID(10), + Name: "test", + Description: &desc, + RemoteID: platform.ID(100), + LocalBucketID: platform.ID(1000), + RemoteBucketID: platform.ID(99999), + MaxQueueSizeBytes: 3 * influxdb.DefaultReplicationMaxQueueSizeBytes, + } + createReq = influxdb.CreateReplicationRequest{ + OrgID: replication.OrgID, + Name: replication.Name, + Description: replication.Description, + RemoteID: replication.RemoteID, + LocalBucketID: replication.LocalBucketID, + RemoteBucketID: replication.RemoteBucketID, + MaxQueueSizeBytes: replication.MaxQueueSizeBytes, + } + httpConfig = internal.ReplicationHTTPConfig{ + RemoteURL: fmt.Sprintf("http://%s.cloud", replication.RemoteID), + RemoteToken: replication.RemoteID.String(), + RemoteOrgID: platform.ID(888888), + AllowInsecureTLS: true, + RemoteBucketID: replication.RemoteBucketID, + } + newRemoteID = platform.ID(200) + newQueueSize = influxdb.MinReplicationMaxQueueSizeBytes + updateReq = influxdb.UpdateReplicationRequest{ + RemoteID: &newRemoteID, + MaxQueueSizeBytes: &newQueueSize, + } + updatedReplication = influxdb.Replication{ + ID: replication.ID, + OrgID: replication.OrgID, + Name: replication.Name, + Description: replication.Description, + RemoteID: *updateReq.RemoteID, + LocalBucketID: replication.LocalBucketID, + RemoteBucketID: replication.RemoteBucketID, + MaxQueueSizeBytes: *updateReq.MaxQueueSizeBytes, + } + updatedHttpConfig = internal.ReplicationHTTPConfig{ + RemoteURL: fmt.Sprintf("http://%s.cloud", updatedReplication.RemoteID), + RemoteToken: updatedReplication.RemoteID.String(), + RemoteOrgID: platform.ID(888888), + AllowInsecureTLS: true, + RemoteBucketID: updatedReplication.RemoteBucketID, + } +) + +func TestCreateAndGetReplication(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Getting or validating an invalid ID should return an error. + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) + require.Equal(t, errReplicationNotFound, svc.ValidateReplication(ctx, initID)) + + // Create a replication, check the results. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Read the created replication and assert it matches the creation response. + got, err = svc.GetReplication(ctx, initID) + require.NoError(t, err) + require.Equal(t, replication, *got) + + // Validate the replication; this is mostly a no-op for this test, but it allows + // us to check that our sql for extracting the linked remote's parameters is correct. + fakeErr := errors.New("O NO") + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(fakeErr) + require.Contains(t, svc.ValidateReplication(ctx, initID).Error(), fakeErr.Error()) +} + +func TestCreateMissingBucket(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + bucketNotFound := errors.New("bucket not found") + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(nil, bucketNotFound) + + created, err := svc.CreateReplication(ctx, createReq) + require.Equal(t, errLocalBucketNotFound(createReq.LocalBucketID, bucketNotFound), err) + require.Nil(t, created) + + // Make sure nothing was persisted. + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) +} + +func TestCreateMissingRemote(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + created, err := svc.CreateReplication(ctx, createReq) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("remote %q not found", createReq.RemoteID)) + require.Nil(t, created) + + // Make sure nothing was persisted. + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) +} + +func TestValidateReplicationWithoutPersisting(t *testing.T) { + t.Parallel() + + t.Run("missing bucket", func(t *testing.T) { + svc, mocks, clean := newTestService(t) + defer clean(t) + + bucketNotFound := errors.New("bucket not found") + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(nil, bucketNotFound) + + require.Equal(t, errLocalBucketNotFound(createReq.LocalBucketID, bucketNotFound), + svc.ValidateNewReplication(ctx, createReq)) + + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) + }) + + t.Run("missing remote", func(t *testing.T) { + svc, mocks, clean := newTestService(t) + defer clean(t) + + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil) + + require.Contains(t, svc.ValidateNewReplication(ctx, createReq).Error(), + fmt.Sprintf("remote %q not found", createReq.RemoteID)) + + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) + }) + + t.Run("validation error", func(t *testing.T) { + svc, mocks, clean := newTestService(t) + defer clean(t) + + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil) + insertRemote(t, svc.store, createReq.RemoteID) + + fakeErr := errors.New("O NO") + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(fakeErr) + + require.Contains(t, svc.ValidateNewReplication(ctx, createReq).Error(), fakeErr.Error()) + + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) + }) + + t.Run("no error", func(t *testing.T) { + svc, mocks, clean := newTestService(t) + defer clean(t) + + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil) + insertRemote(t, svc.store, createReq.RemoteID) + + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(nil) + + require.NoError(t, svc.ValidateNewReplication(ctx, createReq)) + + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) + }) +} + +func TestUpdateAndGetReplication(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + insertRemote(t, svc.store, updatedReplication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Updating a nonexistent ID fails. + updated, err := svc.UpdateReplication(ctx, initID, updateReq) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, updated) + + // Create a replication. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Update the replication. + updated, err = svc.UpdateReplication(ctx, initID, updateReq) + require.NoError(t, err) + require.Equal(t, updatedReplication, *updated) +} + +func TestUpdateMissingRemote(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Create a replication. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Attempt to update the replication to point at a nonexistent remote. + updated, err := svc.UpdateReplication(ctx, initID, updateReq) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("remote %q not found", *updateReq.RemoteID)) + require.Nil(t, updated) + + // Make sure nothing changed in the DB. + got, err := svc.GetReplication(ctx, initID) + require.NoError(t, err) + require.Equal(t, replication, *got) +} + +func TestUpdateNoop(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Create a replication. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Send a no-op update, assert nothing changed. + updated, err := svc.UpdateReplication(ctx, initID, influxdb.UpdateReplicationRequest{}) + require.NoError(t, err) + require.Equal(t, replication, *updated) +} + +func TestValidateUpdatedReplicationWithoutPersisting(t *testing.T) { + t.Parallel() + + t.Run("bad remote", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Create a replication. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Attempt to update the replication to point at a nonexistent remote. + require.Contains(t, svc.ValidateUpdatedReplication(ctx, initID, updateReq).Error(), + fmt.Sprintf("remote %q not found", *updateReq.RemoteID)) + + // Make sure nothing changed in the DB. + got, err := svc.GetReplication(ctx, initID) + require.NoError(t, err) + require.Equal(t, replication, *got) + }) + + t.Run("validation error", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + insertRemote(t, svc.store, updatedReplication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Create a replication. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Check updating to a failing remote, assert error is returned. + fakeErr := errors.New("O NO") + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &updatedHttpConfig).Return(fakeErr) + + require.Contains(t, svc.ValidateUpdatedReplication(ctx, initID, updateReq).Error(), fakeErr.Error()) + + // Make sure nothing changed in the DB. + got, err := svc.GetReplication(ctx, initID) + require.NoError(t, err) + require.Equal(t, replication, *got) + }) + + t.Run("no error", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + insertRemote(t, svc.store, updatedReplication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Create a replication. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Check updating to a remote that passes validation, assert no error. + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &updatedHttpConfig).Return(nil) + + require.NoError(t, svc.ValidateUpdatedReplication(ctx, initID, updateReq)) + + // Make sure nothing changed in the DB. + got, err := svc.GetReplication(ctx, initID) + require.NoError(t, err) + require.Equal(t, replication, *got) + }) +} + +func TestDeleteReplication(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + + insertRemote(t, svc.store, replication.RemoteID) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). + Return(&influxdb.Bucket{}, nil) + + // Deleting a nonexistent ID should return an error. + require.Equal(t, errReplicationNotFound, svc.DeleteReplication(ctx, initID)) + + // Create a replication, then delete it. + created, err := svc.CreateReplication(ctx, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + require.NoError(t, svc.DeleteReplication(ctx, initID)) + + // Looking up the ID should again produce an error. + got, err := svc.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) +} + +func TestListReplications(t *testing.T) { + t.Parallel() + + createReq2, createReq3 := createReq, createReq + createReq2.Name, createReq3.Name = "test2", "test3" + createReq2.LocalBucketID = platform.ID(77777) + createReq3.RemoteID = updatedReplication.RemoteID + + setup := func(t *testing.T, svc *service, mocks mocks) []influxdb.Replication { + mocks.bucketSvc.EXPECT().RLock().Times(3) + mocks.bucketSvc.EXPECT().RUnlock().Times(3) + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil).Times(2) + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq2.LocalBucketID).Return(&influxdb.Bucket{}, nil) + insertRemote(t, svc.store, createReq.RemoteID) + insertRemote(t, svc.store, createReq3.RemoteID) + + var allReplications []influxdb.Replication + for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { + created, err := svc.CreateReplication(ctx, req) + require.NoError(t, err) + allReplications = append(allReplications, *created) + } + return allReplications + } + + t.Run("list all", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + allRepls := setup(t, svc, mocks) + + listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: createReq.OrgID}) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls}, *listed) + }) + + t.Run("list by name", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + allRepls := setup(t, svc, mocks) + + listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: createReq.OrgID, + Name: &createReq2.Name, + }) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls[1:2]}, *listed) + }) + + t.Run("list by remote ID", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + allRepls := setup(t, svc, mocks) + + listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: createReq.OrgID, + RemoteID: &createReq.RemoteID, + }) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls[0:2]}, *listed) + }) + + t.Run("list by bucket ID", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + allRepls := setup(t, svc, mocks) + + listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: createReq.OrgID, + LocalBucketID: &createReq.LocalBucketID, + }) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: append(allRepls[0:1], allRepls[2:]...)}, *listed) + }) + + t.Run("list by other org ID", func(t *testing.T) { + t.Parallel() + + svc, mocks, clean := newTestService(t) + defer clean(t) + setup(t, svc, mocks) + + listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: platform.ID(2)}) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{}, *listed) + }) +} + +type mocks struct { + bucketSvc *replicationsMock.MockBucketService + validator *replicationsMock.MockReplicationValidator +} + +func newTestService(t *testing.T) (*service, mocks, func(t *testing.T)) { + store, clean := sqlite.NewTestStore(t) + logger := zaptest.NewLogger(t) + sqliteMigrator := sqlite.NewMigrator(store, logger) + require.NoError(t, sqliteMigrator.Up(ctx, migrations.All)) + + // Make sure foreign-key checking is enabled. + _, err := store.DB.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + + ctrl := gomock.NewController(t) + mocks := mocks{ + bucketSvc: replicationsMock.NewMockBucketService(ctrl), + validator: replicationsMock.NewMockReplicationValidator(ctrl), + } + svc := service{ + store: store, + idGenerator: mock.NewIncrementingIDGenerator(initID), + bucketService: mocks.bucketSvc, + validator: mocks.validator, + } + + return &svc, mocks, clean +} + +func insertRemote(t *testing.T, store *sqlite.SqlStore, id platform.ID) { + store.Mu.Lock() + defer store.Mu.Unlock() + + q := sq.Insert("remotes").SetMap(sq.Eq{ + "id": id, + "org_id": replication.OrgID, + "name": fmt.Sprintf("foo-%s", id), + "remote_url": fmt.Sprintf("http://%s.cloud", id), + "remote_api_token": id.String(), + "remote_org_id": platform.ID(888888), + "allow_insecure_tls": true, + "created_at": "datetime('now')", + "updated_at": "datetime('now')", + }) + query, args, err := q.ToSql() + require.NoError(t, err) + + _, err = store.DB.Exec(query, args...) + require.NoError(t, err) +} diff --git a/tenant/service.go b/tenant/service.go index a15b38e1ea..01e9aed043 100644 --- a/tenant/service.go +++ b/tenant/service.go @@ -35,6 +35,14 @@ type Service struct { influxdb.BucketService } +func (s *Service) RLock() { + s.store.RLock() +} + +func (s *Service) RUnlock() { + s.store.RUnlock() +} + // NewService creates a new base tenant service. func NewService(st *Store) *Service { svc := &Service{store: st} diff --git a/tenant/storage.go b/tenant/storage.go index c7895ffdac..371f9d7399 100644 --- a/tenant/storage.go +++ b/tenant/storage.go @@ -48,6 +48,14 @@ func NewStore(kvStore kv.Store, opts ...StoreOption) *Store { return store } +func (s *Store) RLock() { + s.kvStore.RLock() +} + +func (s *Store) RUnlock() { + s.kvStore.RUnlock() +} + // View opens up a transaction that will not write to any data. Implementing interfaces // should take care to ensure that all view transactions do not mutate any data. func (s *Store) View(ctx context.Context, fn func(kv.Tx) error) error {