From b8e335fbee011c4e03d073ae67d3107f52a70b58 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 8 Nov 2016 12:09:35 -0500 Subject: [PATCH] Prevent deadlocks with source default enforcement Previously, the logic to enforce on default source relied on the public-facing CRUD methods already provided by SourcesStore. This was prone to deadlocks due to the possibility of acquiring a transaction within a transaction. This extracts the logic that was performed within the transactions of each CRUD action and makes the private methods that receive a *bolt.Tx. This allows the convenience methods that enforce default source to use this private API and provide the transaction from its caller. This ensures that there is only ever one transaction acquired by each expored CRUD method. --- bolt/sources.go | 153 +++++++++++++++++++++++++++++------------------- 1 file changed, 92 insertions(+), 61 deletions(-) diff --git a/bolt/sources.go b/bolt/sources.go index 61506fdfda..b9d3e514f8 100644 --- a/bolt/sources.go +++ b/bolt/sources.go @@ -21,14 +21,9 @@ type SourcesStore struct { func (s *SourcesStore) All(ctx context.Context) ([]chronograf.Source, error) { var srcs []chronograf.Source if err := s.client.db.View(func(tx *bolt.Tx) error { - if err := tx.Bucket(SourcesBucket).ForEach(func(k, v []byte) error { - var src chronograf.Source - if err := internal.UnmarshalSource(v, &src); err != nil { - return err - } - srcs = append(srcs, src) - return nil - }); err != nil { + var err error + srcs, err = s.all(ctx, tx) + if err != nil { return err } return nil @@ -51,25 +46,7 @@ func (s *SourcesStore) Add(ctx context.Context, src chronograf.Source) (chronogr } if err := s.client.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(SourcesBucket) - seq, err := b.NextSequence() - if err != nil { - return err - } - src.ID = int(seq) - - if src.Default { - if err := s.resetDefaultSource(b, ctx); err != nil { - return err - } - } - - if v, err := internal.MarshalSource(src); err != nil { - return err - } else if err := b.Put(itob(src.ID), v); err != nil { - return err - } - return nil + return s.add(ctx, &src, tx) }); err != nil { return chronograf.Source{}, err } @@ -79,16 +56,11 @@ func (s *SourcesStore) Add(ctx context.Context, src chronograf.Source) (chronogr // Delete removes the Source from the SourcesStore func (s *SourcesStore) Delete(ctx context.Context, src chronograf.Source) error { - - if err := s.setRandomDefault(ctx, src); err != nil { - return err - } - if err := s.client.db.Update(func(tx *bolt.Tx) error { - if err := tx.Bucket(SourcesBucket).Delete(itob(src.ID)); err != nil { + if err := s.setRandomDefault(ctx, tx, src); err != nil { return err } - return nil + return s.delete(ctx, src, tx) }); err != nil { return err } @@ -100,9 +72,9 @@ func (s *SourcesStore) Delete(ctx context.Context, src chronograf.Source) error func (s *SourcesStore) Get(ctx context.Context, id int) (chronograf.Source, error) { var src chronograf.Source if err := s.client.db.View(func(tx *bolt.Tx) error { - if v := tx.Bucket(SourcesBucket).Get(itob(id)); v == nil { - return chronograf.ErrSourceNotFound - } else if err := internal.UnmarshalSource(v, &src); err != nil { + var err error + src, err = s.get(ctx, id, tx) + if err != nil { return err } return nil @@ -116,24 +88,7 @@ func (s *SourcesStore) Get(ctx context.Context, id int) (chronograf.Source, erro // Update a Source func (s *SourcesStore) Update(ctx context.Context, src chronograf.Source) error { if err := s.client.db.Update(func(tx *bolt.Tx) error { - // Get an existing soource with the same ID. - b := tx.Bucket(SourcesBucket) - if v := b.Get(itob(src.ID)); v == nil { - return chronograf.ErrSourceNotFound - } - - if src.Default { - if err := s.resetDefaultSource(b, ctx); err != nil { - return err - } - } - - if v, err := internal.MarshalSource(src); err != nil { - return err - } else if err := b.Put(itob(src.ID), v); err != nil { - return err - } - return nil + return s.update(ctx, src, tx) }); err != nil { return err } @@ -141,9 +96,85 @@ func (s *SourcesStore) Update(ctx context.Context, src chronograf.Source) error return nil } +func (s *SourcesStore) all(ctx context.Context, tx *bolt.Tx) ([]chronograf.Source, error) { + var srcs []chronograf.Source + if err := tx.Bucket(SourcesBucket).ForEach(func(k, v []byte) error { + var src chronograf.Source + if err := internal.UnmarshalSource(v, &src); err != nil { + return err + } + srcs = append(srcs, src) + return nil + }); err != nil { + return srcs, err + } + return srcs, nil +} + +func (s *SourcesStore) add(ctx context.Context, src *chronograf.Source, tx *bolt.Tx) error { + b := tx.Bucket(SourcesBucket) + seq, err := b.NextSequence() + if err != nil { + return err + } + src.ID = int(seq) + + if src.Default { + if err := s.resetDefaultSource(tx, ctx); err != nil { + return err + } + } + + if v, err := internal.MarshalSource(*src); err != nil { + return err + } else if err := b.Put(itob(src.ID), v); err != nil { + return err + } + return nil +} + +func (s *SourcesStore) delete(ctx context.Context, src chronograf.Source, tx *bolt.Tx) error { + if err := tx.Bucket(SourcesBucket).Delete(itob(src.ID)); err != nil { + return err + } + return nil +} + +func (s *SourcesStore) get(ctx context.Context, id int, tx *bolt.Tx) (chronograf.Source, error) { + var src chronograf.Source + if v := tx.Bucket(SourcesBucket).Get(itob(id)); v == nil { + return src, chronograf.ErrSourceNotFound + } else if err := internal.UnmarshalSource(v, &src); err != nil { + return src, err + } + return src, nil +} + +func (s *SourcesStore) update(ctx context.Context, src chronograf.Source, tx *bolt.Tx) error { + // Get an existing soource with the same ID. + b := tx.Bucket(SourcesBucket) + if v := b.Get(itob(src.ID)); v == nil { + return chronograf.ErrSourceNotFound + } + + if src.Default { + if err := s.resetDefaultSource(tx, ctx); err != nil { + return err + } + } + + if v, err := internal.MarshalSource(src); err != nil { + return err + } else if err := b.Put(itob(src.ID), v); err != nil { + return err + } + return nil +} + // resetDefaultSource unsets the Default flag on all sources -func (s *SourcesStore) resetDefaultSource(b *bolt.Bucket, ctx context.Context) error { - srcs, err := s.All(ctx) +func (s *SourcesStore) resetDefaultSource(tx *bolt.Tx, ctx context.Context) error { + b := tx.Bucket(SourcesBucket) + srcs, err := s.all(ctx, tx) if err != nil { return err } @@ -165,13 +196,13 @@ func (s *SourcesStore) resetDefaultSource(b *bolt.Bucket, ctx context.Context) e // chronograf.Source and set it as the default source. If no other sources are // available, the provided source will be set to the default source if is not // already. It assumes that the provided chronograf.Source has been persisted. -func (s *SourcesStore) setRandomDefault(ctx context.Context, src chronograf.Source) error { +func (s *SourcesStore) setRandomDefault(ctx context.Context, tx *bolt.Tx, src chronograf.Source) error { // Check if requested source is the current default - if target, err := s.Get(ctx, src.ID); err != nil { + if target, err := s.get(ctx, src.ID, tx); err != nil { return err } else if target.Default { // Locate another source to be the new default - if srcs, err := s.All(ctx); err != nil { + if srcs, err := s.all(ctx, tx); err != nil { return err } else { var other *chronograf.Source @@ -185,7 +216,7 @@ func (s *SourcesStore) setRandomDefault(ctx context.Context, src chronograf.Sour // set the other to be the default other.Default = true - if err := s.Update(ctx, *other); err != nil { + if err := s.update(ctx, *other, tx); err != nil { return err } }