diff --git a/bolt/sources.go b/bolt/sources.go index 61506fdfda..b9d3e514f8 100644 --- a/bolt/sources.go +++ b/bolt/sources.go @@ -21,14 +21,9 @@ type SourcesStore struct { func (s *SourcesStore) All(ctx context.Context) ([]chronograf.Source, error) { var srcs []chronograf.Source if err := s.client.db.View(func(tx *bolt.Tx) error { - if err := tx.Bucket(SourcesBucket).ForEach(func(k, v []byte) error { - var src chronograf.Source - if err := internal.UnmarshalSource(v, &src); err != nil { - return err - } - srcs = append(srcs, src) - return nil - }); err != nil { + var err error + srcs, err = s.all(ctx, tx) + if err != nil { return err } return nil @@ -51,25 +46,7 @@ func (s *SourcesStore) Add(ctx context.Context, src chronograf.Source) (chronogr } if err := s.client.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(SourcesBucket) - seq, err := b.NextSequence() - if err != nil { - return err - } - src.ID = int(seq) - - if src.Default { - if err := s.resetDefaultSource(b, ctx); err != nil { - return err - } - } - - if v, err := internal.MarshalSource(src); err != nil { - return err - } else if err := b.Put(itob(src.ID), v); err != nil { - return err - } - return nil + return s.add(ctx, &src, tx) }); err != nil { return chronograf.Source{}, err } @@ -79,16 +56,11 @@ func (s *SourcesStore) Add(ctx context.Context, src chronograf.Source) (chronogr // Delete removes the Source from the SourcesStore func (s *SourcesStore) Delete(ctx context.Context, src chronograf.Source) error { - - if err := s.setRandomDefault(ctx, src); err != nil { - return err - } - if err := s.client.db.Update(func(tx *bolt.Tx) error { - if err := tx.Bucket(SourcesBucket).Delete(itob(src.ID)); err != nil { + if err := s.setRandomDefault(ctx, tx, src); err != nil { return err } - return nil + return s.delete(ctx, src, tx) }); err != nil { return err } @@ -100,9 +72,9 @@ func (s *SourcesStore) Delete(ctx context.Context, src chronograf.Source) error func (s *SourcesStore) Get(ctx context.Context, id int) (chronograf.Source, error) { var src chronograf.Source if err := s.client.db.View(func(tx *bolt.Tx) error { - if v := tx.Bucket(SourcesBucket).Get(itob(id)); v == nil { - return chronograf.ErrSourceNotFound - } else if err := internal.UnmarshalSource(v, &src); err != nil { + var err error + src, err = s.get(ctx, id, tx) + if err != nil { return err } return nil @@ -116,24 +88,7 @@ func (s *SourcesStore) Get(ctx context.Context, id int) (chronograf.Source, erro // Update a Source func (s *SourcesStore) Update(ctx context.Context, src chronograf.Source) error { if err := s.client.db.Update(func(tx *bolt.Tx) error { - // Get an existing soource with the same ID. - b := tx.Bucket(SourcesBucket) - if v := b.Get(itob(src.ID)); v == nil { - return chronograf.ErrSourceNotFound - } - - if src.Default { - if err := s.resetDefaultSource(b, ctx); err != nil { - return err - } - } - - if v, err := internal.MarshalSource(src); err != nil { - return err - } else if err := b.Put(itob(src.ID), v); err != nil { - return err - } - return nil + return s.update(ctx, src, tx) }); err != nil { return err } @@ -141,9 +96,85 @@ func (s *SourcesStore) Update(ctx context.Context, src chronograf.Source) error return nil } +func (s *SourcesStore) all(ctx context.Context, tx *bolt.Tx) ([]chronograf.Source, error) { + var srcs []chronograf.Source + if err := tx.Bucket(SourcesBucket).ForEach(func(k, v []byte) error { + var src chronograf.Source + if err := internal.UnmarshalSource(v, &src); err != nil { + return err + } + srcs = append(srcs, src) + return nil + }); err != nil { + return srcs, err + } + return srcs, nil +} + +func (s *SourcesStore) add(ctx context.Context, src *chronograf.Source, tx *bolt.Tx) error { + b := tx.Bucket(SourcesBucket) + seq, err := b.NextSequence() + if err != nil { + return err + } + src.ID = int(seq) + + if src.Default { + if err := s.resetDefaultSource(tx, ctx); err != nil { + return err + } + } + + if v, err := internal.MarshalSource(*src); err != nil { + return err + } else if err := b.Put(itob(src.ID), v); err != nil { + return err + } + return nil +} + +func (s *SourcesStore) delete(ctx context.Context, src chronograf.Source, tx *bolt.Tx) error { + if err := tx.Bucket(SourcesBucket).Delete(itob(src.ID)); err != nil { + return err + } + return nil +} + +func (s *SourcesStore) get(ctx context.Context, id int, tx *bolt.Tx) (chronograf.Source, error) { + var src chronograf.Source + if v := tx.Bucket(SourcesBucket).Get(itob(id)); v == nil { + return src, chronograf.ErrSourceNotFound + } else if err := internal.UnmarshalSource(v, &src); err != nil { + return src, err + } + return src, nil +} + +func (s *SourcesStore) update(ctx context.Context, src chronograf.Source, tx *bolt.Tx) error { + // Get an existing soource with the same ID. + b := tx.Bucket(SourcesBucket) + if v := b.Get(itob(src.ID)); v == nil { + return chronograf.ErrSourceNotFound + } + + if src.Default { + if err := s.resetDefaultSource(tx, ctx); err != nil { + return err + } + } + + if v, err := internal.MarshalSource(src); err != nil { + return err + } else if err := b.Put(itob(src.ID), v); err != nil { + return err + } + return nil +} + // resetDefaultSource unsets the Default flag on all sources -func (s *SourcesStore) resetDefaultSource(b *bolt.Bucket, ctx context.Context) error { - srcs, err := s.All(ctx) +func (s *SourcesStore) resetDefaultSource(tx *bolt.Tx, ctx context.Context) error { + b := tx.Bucket(SourcesBucket) + srcs, err := s.all(ctx, tx) if err != nil { return err } @@ -165,13 +196,13 @@ func (s *SourcesStore) resetDefaultSource(b *bolt.Bucket, ctx context.Context) e // chronograf.Source and set it as the default source. If no other sources are // available, the provided source will be set to the default source if is not // already. It assumes that the provided chronograf.Source has been persisted. -func (s *SourcesStore) setRandomDefault(ctx context.Context, src chronograf.Source) error { +func (s *SourcesStore) setRandomDefault(ctx context.Context, tx *bolt.Tx, src chronograf.Source) error { // Check if requested source is the current default - if target, err := s.Get(ctx, src.ID); err != nil { + if target, err := s.get(ctx, src.ID, tx); err != nil { return err } else if target.Default { // Locate another source to be the new default - if srcs, err := s.All(ctx); err != nil { + if srcs, err := s.all(ctx, tx); err != nil { return err } else { var other *chronograf.Source @@ -185,7 +216,7 @@ func (s *SourcesStore) setRandomDefault(ctx context.Context, src chronograf.Sour // set the other to be the default other.Default = true - if err := s.Update(ctx, *other); err != nil { + if err := s.update(ctx, *other, tx); err != nil { return err } }