diff --git a/cmd/influxd/run/server.go b/cmd/influxd/run/server.go index d43a19872c..abde5dc572 100644 --- a/cmd/influxd/run/server.go +++ b/cmd/influxd/run/server.go @@ -321,11 +321,7 @@ func (s *Server) appendPrecreatorService(c precreator.Config) error { if !c.Enabled { return nil } - srv, err := precreator.NewService(c) - if err != nil { - return err - } - + srv := precreator.NewService(c) srv.MetaClient = s.MetaClient s.Services = append(s.Services, srv) return nil diff --git a/internal/meta_client.go b/internal/meta_client.go index 916d8568e6..7234627caa 100644 --- a/internal/meta_client.go +++ b/internal/meta_client.go @@ -32,7 +32,8 @@ type MetaClientMock struct { OpenFn func() error - PruneShardGroupsFn func() error + PrecreateShardGroupsFn func(from, to time.Time) error + PruneShardGroupsFn func() error RetentionPolicyFn func(database, name string) (rpi *meta.RetentionPolicyInfo, err error) @@ -167,4 +168,7 @@ func (c *MetaClientMock) Open() error { return c.OpenFn() } func (c *MetaClientMock) Data() meta.Data { return c.DataFn() } func (c *MetaClientMock) SetData(d *meta.Data) error { return c.SetDataFn(d) } +func (c *MetaClientMock) PrecreateShardGroups(from, to time.Time) error { + return c.PrecreateShardGroupsFn(from, to) +} func (c *MetaClientMock) PruneShardGroups() error { return c.PruneShardGroupsFn() } diff --git a/monitor/build_info_test.go b/monitor/build_info_test.go new file mode 100644 index 0000000000..851ed3b13d --- /dev/null +++ b/monitor/build_info_test.go @@ -0,0 +1,43 @@ +package monitor_test + +import ( + "reflect" + "testing" + + "github.com/influxdata/influxdb/monitor" +) + +func TestDiagnostics_BuildInfo(t *testing.T) { + s := monitor.New(nil, monitor.Config{}) + s.Version = "1.2.0" + s.Commit = "b7bb7e8359642b6e071735b50ae41f5eb343fd42" + s.Branch = "1.2" + s.BuildTime = "10m30s" + + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + d, err := s.Diagnostics() + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + + diags, ok := d["build"] + if !ok { + t.Error("no diagnostics found for 'build'") + return + } + + if got, exp := diags.Columns, []string{"Branch", "Build Time", "Commit", "Version"}; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected columns: got=%v exp=%v", got, exp) + } + + if got, exp := diags.Rows, [][]interface{}{ + []interface{}{"1.2", "10m30s", "b7bb7e8359642b6e071735b50ae41f5eb343fd42", "1.2.0"}, + }; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected rows: got=%v exp=%v", got, exp) + } +} diff --git a/monitor/go_runtime_test.go b/monitor/go_runtime_test.go new file mode 100644 index 0000000000..dc52b66238 --- /dev/null +++ b/monitor/go_runtime_test.go @@ -0,0 +1,39 @@ +package monitor_test + +import ( + "reflect" + "runtime" + "testing" + + "github.com/influxdata/influxdb/monitor" +) + +func TestDiagnostics_GoRuntime(t *testing.T) { + s := monitor.New(nil, monitor.Config{}) + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + d, err := s.Diagnostics() + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + + diags, ok := d["runtime"] + if !ok { + t.Error("no diagnostics found for 'runtime'") + return + } + + if got, exp := diags.Columns, []string{"GOARCH", "GOMAXPROCS", "GOOS", "version"}; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected columns: got=%v exp=%v", got, exp) + } + + if got, exp := diags.Rows, [][]interface{}{ + []interface{}{runtime.GOARCH, runtime.GOMAXPROCS(-1), runtime.GOOS, 
runtime.Version()}, + }; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected rows: got=%v exp=%v", got, exp) + } +} diff --git a/monitor/network_test.go b/monitor/network_test.go new file mode 100644 index 0000000000..0615e0a11c --- /dev/null +++ b/monitor/network_test.go @@ -0,0 +1,44 @@ +package monitor_test + +import ( + "os" + "reflect" + "testing" + + "github.com/influxdata/influxdb/monitor" +) + +func TestDiagnostics_Network(t *testing.T) { + hostname, err := os.Hostname() + if err != nil { + t.Fatalf("unexpected error retrieving hostname: %s", err) + } + + s := monitor.New(nil, monitor.Config{}) + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + d, err := s.Diagnostics() + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + + diags, ok := d["network"] + if !ok { + t.Error("no diagnostics found for 'network'") + return + } + + if got, exp := diags.Columns, []string{"hostname"}; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected columns: got=%v exp=%v", got, exp) + } + + if got, exp := diags.Rows, [][]interface{}{ + []interface{}{hostname}, + }; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected rows: got=%v exp=%v", got, exp) + } +} diff --git a/monitor/service.go b/monitor/service.go index 82ee8d21da..11e5bfcac3 100644 --- a/monitor/service.go +++ b/monitor/service.go @@ -220,7 +220,7 @@ func (m *Monitor) RegisterDiagnosticsClient(name string, client diagnostics.Clie m.mu.Lock() defer m.mu.Unlock() m.diagRegistrations[name] = client - m.Logger.Info(fmt.Sprintf(`'%s' registered for diagnostics monitoring`, name)) + m.Logger.Info("registered for diagnostics monitoring", zap.String("name", name)) } // DeregisterDiagnosticsClient deregisters a diagnostics client by name. @@ -250,8 +250,11 @@ func (m *Monitor) Statistics(tags map[string]string) ([]*Statistic, error) { statistic.Tags[k] = v } - // Every other top-level expvar value is a map. - m := kv.Value.(*expvar.Map) + // Every other top-level expvar value should be a map. + m, ok := kv.Value.(*expvar.Map) + if !ok { + return + } m.Do(func(subKV expvar.KeyValue) { switch subKV.Key { @@ -344,8 +347,10 @@ func (m *Monitor) gatherStatistics(statistics []*Statistic, tags map[string]stri m.mu.RLock() defer m.mu.RUnlock() - for _, s := range m.reporter.Statistics(tags) { - statistics = append(statistics, &Statistic{Statistic: s}) + if m.reporter != nil { + for _, s := range m.reporter.Statistics(tags) { + statistics = append(statistics, &Statistic{Statistic: s}) + } } return statistics } diff --git a/monitor/service_test.go b/monitor/service_test.go new file mode 100644 index 0000000000..58367532fa --- /dev/null +++ b/monitor/service_test.go @@ -0,0 +1,478 @@ +package monitor_test + +import ( + "bytes" + "expvar" + "fmt" + "os" + "reflect" + "sort" + "testing" + "time" + + "github.com/influxdata/influxdb/models" + "github.com/influxdata/influxdb/monitor" + "github.com/influxdata/influxdb/services/meta" + "github.com/influxdata/influxdb/toml" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestMonitor_Open(t *testing.T) { + s := monitor.New(nil, monitor.Config{}) + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + + // Verify that opening twice is fine. + if err := s.Open(); err != nil { + s.Close() + t.Fatalf("unexpected error on second open: %s", err) + } + + if err := s.Close(); err != nil { + t.Fatalf("unexpected close error: %s", err) + } + + // Verify that closing twice is fine. 
+ if err := s.Close(); err != nil { + t.Fatalf("unexpected error on second close: %s", err) + } +} + +func TestMonitor_SetPointsWriter_StoreEnabled(t *testing.T) { + var mc MetaClient + mc.CreateDatabaseWithRetentionPolicyFn = func(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) { + return &meta.DatabaseInfo{Name: name}, nil + } + + config := monitor.NewConfig() + s := monitor.New(nil, config) + s.MetaClient = &mc + core, logs := observer.New(zap.DebugLevel) + s.WithLogger(zap.New(core)) + + // Setting the points writer should open the monitor. + var pw PointsWriter + if err := s.SetPointsWriter(&pw); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + // Verify that the monitor was opened by looking at the log messages. + if logs.FilterMessage("Starting monitor system").Len() == 0 { + t.Errorf("monitor system was never started") + } +} + +func TestMonitor_SetPointsWriter_StoreDisabled(t *testing.T) { + s := monitor.New(nil, monitor.Config{}) + core, logs := observer.New(zap.DebugLevel) + s.WithLogger(zap.New(core)) + + // Setting the points writer should open the monitor. + var pw PointsWriter + if err := s.SetPointsWriter(&pw); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + // Verify that the monitor was not opened by looking at the log messages. + if logs.FilterMessage("Starting monitor system").Len() > 0 { + t.Errorf("monitor system should not have been started") + } +} + +func TestMonitor_StoreStatistics(t *testing.T) { + done := make(chan struct{}) + defer close(done) + ch := make(chan models.Points) + + var mc MetaClient + mc.CreateDatabaseWithRetentionPolicyFn = func(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) { + if got, want := name, monitor.DefaultStoreDatabase; got != want { + t.Errorf("unexpected database: got=%q want=%q", got, want) + } + if got, want := spec.Name, monitor.MonitorRetentionPolicy; got != want { + t.Errorf("unexpected retention policy: got=%q want=%q", got, want) + } + if spec.Duration != nil { + if got, want := *spec.Duration, monitor.MonitorRetentionPolicyDuration; got != want { + t.Errorf("unexpected duration: got=%q want=%q", got, want) + } + } else { + t.Error("expected duration in retention policy spec") + } + if spec.ReplicaN != nil { + if got, want := *spec.ReplicaN, monitor.MonitorRetentionPolicyReplicaN; got != want { + t.Errorf("unexpected replica number: got=%q want=%q", got, want) + } + } else { + t.Error("expected replica number in retention policy spec") + } + return &meta.DatabaseInfo{Name: name}, nil + } + + var pw PointsWriter + pw.WritePointsFn = func(database, policy string, points models.Points) error { + // Verify that we are attempting to write to the correct database. + if got, want := database, monitor.DefaultStoreDatabase; got != want { + t.Errorf("unexpected database: got=%q want=%q", got, want) + } + if got, want := policy, monitor.MonitorRetentionPolicy; got != want { + t.Errorf("unexpected retention policy: got=%q want=%q", got, want) + } + + // Attempt to write the points to the main goroutine. 
+ select { + case <-done: + case ch <- points: + } + return nil + } + + config := monitor.NewConfig() + config.StoreInterval = toml.Duration(10 * time.Millisecond) + s := monitor.New(nil, config) + s.MetaClient = &mc + s.PointsWriter = &pw + + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + timer := time.NewTimer(100 * time.Millisecond) + select { + case points := <-ch: + timer.Stop() + + // Search for the runtime statistic. + found := false + for _, pt := range points { + if !bytes.Equal(pt.Name(), []byte("runtime")) { + continue + } + + // There should be a hostname. + if got := pt.Tags().GetString("hostname"); len(got) == 0 { + t.Errorf("expected hostname tag") + } + // This should write on an exact interval of 10 milliseconds. + if got, want := pt.Time(), pt.Time().Truncate(10*time.Millisecond); got != want { + t.Errorf("unexpected time: got=%q want=%q", got, want) + } + found = true + break + } + + if !found { + t.Error("unable to find runtime statistic") + } + case <-timer.C: + t.Errorf("timeout while waiting for statistics to be written") + } +} + +func TestMonitor_Reporter(t *testing.T) { + reporter := ReporterFunc(func(tags map[string]string) []models.Statistic { + return []models.Statistic{ + { + Name: "foo", + Tags: tags, + Values: map[string]interface{}{ + "value": "bar", + }, + }, + } + }) + + done := make(chan struct{}) + defer close(done) + ch := make(chan models.Points) + + var mc MetaClient + mc.CreateDatabaseWithRetentionPolicyFn = func(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) { + return &meta.DatabaseInfo{Name: name}, nil + } + + var pw PointsWriter + pw.WritePointsFn = func(database, policy string, points models.Points) error { + // Attempt to write the points to the main goroutine. + select { + case <-done: + case ch <- points: + } + return nil + } + + config := monitor.NewConfig() + config.StoreInterval = toml.Duration(10 * time.Millisecond) + s := monitor.New(reporter, config) + s.MetaClient = &mc + s.PointsWriter = &pw + + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + timer := time.NewTimer(100 * time.Millisecond) + select { + case points := <-ch: + timer.Stop() + + // Look for the statistic. 
+ found := false + for _, pt := range points { + if !bytes.Equal(pt.Name(), []byte("foo")) { + continue + } + found = true + break + } + + if !found { + t.Error("unable to find foo statistic") + } + case <-timer.C: + t.Errorf("timeout while waiting for statistics to be written") + } +} + +func expvarMap(name string, tags map[string]string, fields map[string]interface{}) *expvar.Map { + m := new(expvar.Map).Init() + eName := new(expvar.String) + eName.Set(name) + m.Set("name", eName) + + var eTags *expvar.Map + if len(tags) > 0 { + eTags = new(expvar.Map).Init() + for k, v := range tags { + kv := new(expvar.String) + kv.Set(v) + eTags.Set(k, kv) + } + m.Set("tags", eTags) + } + + var eFields *expvar.Map + if len(fields) > 0 { + eFields = new(expvar.Map).Init() + for k, v := range fields { + switch v := v.(type) { + case float64: + kv := new(expvar.Float) + kv.Set(v) + eFields.Set(k, kv) + case int: + kv := new(expvar.Int) + kv.Set(int64(v)) + eFields.Set(k, kv) + case string: + kv := new(expvar.String) + kv.Set(v) + eFields.Set(k, kv) + } + } + m.Set("values", eFields) + } + return m +} + +func TestMonitor_Expvar(t *testing.T) { + done := make(chan struct{}) + defer close(done) + ch := make(chan models.Points) + + var mc MetaClient + mc.CreateDatabaseWithRetentionPolicyFn = func(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) { + return &meta.DatabaseInfo{Name: name}, nil + } + + var pw PointsWriter + pw.WritePointsFn = func(database, policy string, points models.Points) error { + // Attempt to write the points to the main goroutine. + select { + case <-done: + case ch <- points: + } + return nil + } + + config := monitor.NewConfig() + config.StoreInterval = toml.Duration(10 * time.Millisecond) + s := monitor.New(nil, config) + s.MetaClient = &mc + s.PointsWriter = &pw + + expvar.Publish("expvar1", expvarMap( + "expvar1", + map[string]string{ + "region": "uswest2", + }, + map[string]interface{}{ + "value": 2.0, + }, + )) + expvar.Publish("expvar2", expvarMap( + "expvar2", + map[string]string{ + "region": "uswest2", + }, + nil, + )) + expvar.Publish("expvar3", expvarMap( + "expvar3", + nil, + map[string]interface{}{ + "value": 2, + }, + )) + + bad := new(expvar.String) + bad.Set("badentry") + expvar.Publish("expvar4", bad) + + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + hostname, _ := os.Hostname() + timer := time.NewTimer(100 * time.Millisecond) + select { + case points := <-ch: + timer.Stop() + + // Look for the statistic. 
+ var found1, found3 bool + for _, pt := range points { + if bytes.Equal(pt.Name(), []byte("expvar1")) { + if got, want := pt.Tags().HashKey(), []byte(fmt.Sprintf(",hostname=%s,region=uswest2", hostname)); !reflect.DeepEqual(got, want) { + t.Errorf("unexpected expvar1 tags: got=%v want=%v", string(got), string(want)) + } + fields, _ := pt.Fields() + if got, want := fields, models.Fields(map[string]interface{}{ + "value": 2.0, + }); !reflect.DeepEqual(got, want) { + t.Errorf("unexpected expvar1 fields: got=%v want=%v", got, want) + } + found1 = true + } else if bytes.Equal(pt.Name(), []byte("expvar2")) { + t.Error("found expvar2 statistic") + } else if bytes.Equal(pt.Name(), []byte("expvar3")) { + if got, want := pt.Tags().HashKey(), []byte(fmt.Sprintf(",hostname=%s", hostname)); !reflect.DeepEqual(got, want) { + t.Errorf("unexpected expvar3 tags: got=%v want=%v", string(got), string(want)) + } + fields, _ := pt.Fields() + if got, want := fields, models.Fields(map[string]interface{}{ + "value": int64(2), + }); !reflect.DeepEqual(got, want) { + t.Errorf("unexpected expvar3 fields: got=%v want=%v", got, want) + } + found3 = true + } + } + + if !found1 { + t.Error("unable to find expvar1 statistic") + } + if !found3 { + t.Error("unable to find expvar3 statistic") + } + case <-timer.C: + t.Errorf("timeout while waiting for statistics to be written") + } +} + +func TestMonitor_QuickClose(t *testing.T) { + var mc MetaClient + mc.CreateDatabaseWithRetentionPolicyFn = func(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) { + return &meta.DatabaseInfo{Name: name}, nil + } + + var pw PointsWriter + config := monitor.NewConfig() + config.StoreInterval = toml.Duration(24 * time.Hour) + s := monitor.New(nil, config) + s.MetaClient = &mc + s.PointsWriter = &pw + + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if err := s.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestStatistic_ValueNames(t *testing.T) { + statistic := monitor.Statistic{ + Statistic: models.Statistic{ + Name: "foo", + Values: map[string]interface{}{ + "abc": 1.0, + "def": 2.0, + }, + }, + } + + names := statistic.ValueNames() + if got, want := names, []string{"abc", "def"}; !reflect.DeepEqual(got, want) { + t.Errorf("unexpected value names: got=%v want=%v", got, want) + } +} + +func TestStatistics_Sort(t *testing.T) { + statistics := []*monitor.Statistic{ + {Statistic: models.Statistic{Name: "b"}}, + {Statistic: models.Statistic{Name: "a"}}, + {Statistic: models.Statistic{Name: "c"}}, + } + + sort.Sort(monitor.Statistics(statistics)) + names := make([]string, 0, len(statistics)) + for _, stat := range statistics { + names = append(names, stat.Name) + } + + if got, want := names, []string{"a", "b", "c"}; !reflect.DeepEqual(got, want) { + t.Errorf("incorrect sorting of statistics: got=%v want=%v", got, want) + } +} + +type ReporterFunc func(tags map[string]string) []models.Statistic + +func (f ReporterFunc) Statistics(tags map[string]string) []models.Statistic { + return f(tags) +} + +type PointsWriter struct { + WritePointsFn func(database, policy string, points models.Points) error +} + +func (pw *PointsWriter) WritePoints(database, policy string, points models.Points) error { + if pw.WritePointsFn != nil { + return pw.WritePointsFn(database, policy, points) + } + return nil +} + +type MetaClient struct { + CreateDatabaseWithRetentionPolicyFn func(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) + DatabaseFn func(name 
string) *meta.DatabaseInfo +} + +func (m *MetaClient) CreateDatabaseWithRetentionPolicy(name string, spec *meta.RetentionPolicySpec) (*meta.DatabaseInfo, error) { + return m.CreateDatabaseWithRetentionPolicyFn(name, spec) +} + +func (m *MetaClient) Database(name string) *meta.DatabaseInfo { + if m.DatabaseFn != nil { + return m.DatabaseFn(name) + } + return nil +} diff --git a/monitor/system.go b/monitor/system.go index bbeab8a4f8..01a6bc5016 100644 --- a/monitor/system.go +++ b/monitor/system.go @@ -17,11 +17,12 @@ func init() { type system struct{} func (s *system) Diagnostics() (*diagnostics.Diagnostics, error) { + currentTime := time.Now().UTC() d := map[string]interface{}{ "PID": os.Getpid(), - "currentTime": time.Now().UTC(), + "currentTime": currentTime, "started": startTime, - "uptime": time.Since(startTime).String(), + "uptime": currentTime.Sub(startTime).String(), } return diagnostics.RowFromMap(d), nil diff --git a/monitor/system_test.go b/monitor/system_test.go new file mode 100644 index 0000000000..923345b896 --- /dev/null +++ b/monitor/system_test.go @@ -0,0 +1,55 @@ +package monitor_test + +import ( + "os" + "reflect" + "testing" + "time" + + "github.com/influxdata/influxdb/monitor" +) + +func TestDiagnostics_System(t *testing.T) { + s := monitor.New(nil, monitor.Config{}) + if err := s.Open(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer s.Close() + + d, err := s.Diagnostics() + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + + diags, ok := d["system"] + if !ok { + t.Fatal("no diagnostics found for 'system'") + } + + if got, exp := diags.Columns, []string{"PID", "currentTime", "started", "uptime"}; !reflect.DeepEqual(got, exp) { + t.Errorf("unexpected columns: got=%v exp=%v", got, exp) + } + + // So this next part is nearly impossible to match, so just check if they look correct. + if exp, got := 1, len(diags.Rows); exp != got { + t.Fatalf("expected exactly %d row, got %d", exp, got) + } + + if got, exp := diags.Rows[0][0].(int), os.Getpid(); got != exp { + t.Errorf("unexpected pid: got=%v exp=%v", got, exp) + } + + currentTime := diags.Rows[0][1].(time.Time) + startTime := diags.Rows[0][2].(time.Time) + if !startTime.Before(currentTime) { + t.Errorf("start time is not before the current time: %s (start), %s (current)", startTime, currentTime) + } + + uptime, err := time.ParseDuration(diags.Rows[0][3].(string)) + if err != nil { + t.Errorf("unable to parse uptime duration: %s: %s", diags.Rows[0][3], err) + } else if got, exp := uptime, currentTime.Sub(startTime); got != exp { + t.Errorf("uptime does not match the difference between start time and current time: got=%v exp=%v", got, exp) + } +} diff --git a/services/precreator/config_test.go b/services/precreator/config_test.go index da87fcf53f..a7427a328f 100644 --- a/services/precreator/config_test.go +++ b/services/precreator/config_test.go @@ -59,4 +59,9 @@ func TestConfig_Validate(t *testing.T) { if err := c.Validate(); err == nil { t.Fatal("expected error for negative advance-period, got nil") } + + c.Enabled = false + if err := c.Validate(); err != nil { + t.Fatalf("unexpected validation fail from disabled config: %s", err) + } } diff --git a/services/precreator/service.go b/services/precreator/service.go index a90192b595..b7382a365c 100644 --- a/services/precreator/service.go +++ b/services/precreator/service.go @@ -25,14 +25,12 @@ type Service struct { } // NewService returns an instance of the precreation service. 
-func NewService(c Config) (*Service, error) { - s := Service{ +func NewService(c Config) *Service { + return &Service{ checkInterval: time.Duration(c.CheckInterval), advancePeriod: time.Duration(c.AdvancePeriod), Logger: zap.NewNop(), } - - return &s, nil } // WithLogger sets the logger for the service. diff --git a/services/precreator/service_test.go b/services/precreator/service_test.go index bb2d20ddf9..8005ef812b 100644 --- a/services/precreator/service_test.go +++ b/services/precreator/service_test.go @@ -1,55 +1,55 @@ -package precreator +package precreator_test import ( - "sync" + "os" "testing" "time" + "github.com/influxdata/influxdb/internal" + "github.com/influxdata/influxdb/logger" + "github.com/influxdata/influxdb/services/precreator" "github.com/influxdata/influxdb/toml" ) -func Test_ShardPrecreation(t *testing.T) { - t.Parallel() +func TestShardPrecreation(t *testing.T) { + done := make(chan struct{}) + precreate := false - now := time.Now().UTC() - advancePeriod := 5 * time.Minute - - // A test metastaore which returns 2 shard groups, only 1 of which requires a successor. - var wg sync.WaitGroup - wg.Add(1) - ms := metaClient{ - PrecreateShardGroupsFn: func(v, u time.Time) error { - wg.Done() - if u != now.Add(advancePeriod) { - t.Fatalf("precreation called with wrong time, got %s, exp %s", u, now) - } - return nil - }, + var mc internal.MetaClientMock + mc.PrecreateShardGroupsFn = func(now, cutoff time.Time) error { + if !precreate { + close(done) + precreate = true + } + return nil } - srv, err := NewService(Config{ - CheckInterval: toml.Duration(time.Minute), - AdvancePeriod: toml.Duration(advancePeriod), - }) - if err != nil { - t.Fatalf("failed to create shard precreation service: %s", err.Error()) - } - srv.MetaClient = ms + s := NewTestService() + s.MetaClient = &mc - err = srv.precreate(now) - if err != nil { - t.Fatalf("failed to precreate shards: %s", err.Error()) + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() // double close should not cause a panic + + timer := time.NewTimer(100 * time.Millisecond) + select { + case <-done: + timer.Stop() + case <-timer.C: + t.Errorf("timeout exceeded while waiting for precreate") } - wg.Wait() // Ensure metaClient test function is called. - return + if err := s.Close(); err != nil { + t.Fatalf("unexpected close error: %s", err) + } } -// PointsWriter represents a mock impl of PointsWriter. 
-type metaClient struct { - PrecreateShardGroupsFn func(now, cutoff time.Time) error -} +func NewTestService() *precreator.Service { + config := precreator.NewConfig() + config.CheckInterval = toml.Duration(10 * time.Millisecond) -func (m metaClient) PrecreateShardGroups(now, cutoff time.Time) error { - return m.PrecreateShardGroupsFn(now, cutoff) + s := precreator.NewService(config) + s.WithLogger(logger.New(os.Stderr)) + return s } diff --git a/services/retention/config_test.go b/services/retention/config_test.go index e373d1e06d..895295cfa8 100644 --- a/services/retention/config_test.go +++ b/services/retention/config_test.go @@ -43,4 +43,9 @@ func TestConfig_Validate(t *testing.T) { if err := c.Validate(); err == nil { t.Fatal("expected error for negative check-interval, got nil") } + + c.Enabled = false + if err := c.Validate(); err != nil { + t.Fatalf("unexpected validation fail from disabled config: %s", err) + } } diff --git a/services/retention/service.go b/services/retention/service.go index 5323fa6a29..b7483319ca 100644 --- a/services/retention/service.go +++ b/services/retention/service.go @@ -43,7 +43,7 @@ func (s *Service) Open() error { return nil } - s.logger.Info(fmt.Sprint("Starting retention policy enforcement service with check interval of ", s.config.CheckInterval)) + s.logger.Info("Starting retention policy enforcement service", zap.String("check-interval", s.config.CheckInterval.String())) s.done = make(chan struct{}) s.wg.Add(1) diff --git a/services/retention/service_test.go b/services/retention/service_test.go index 2f066e6de5..d5d5dad7ca 100644 --- a/services/retention/service_test.go +++ b/services/retention/service_test.go @@ -57,6 +57,154 @@ func TestService_OpenClose(t *testing.T) { } } +func TestService_CheckShards(t *testing.T) { + now := time.Now() + // Account for any time difference that could cause some of the logic in + // this test to fail due to a race condition. If we are at the very end of + // the hour, we can choose a time interval based on one "now" time and then + // run the retention service in the next hour. If we're in one of those + // situations, wait 100 milliseconds until we're in the next hour. 
+ if got, want := now.Add(100*time.Millisecond).Truncate(time.Hour), now.Truncate(time.Hour); !got.Equal(want) { + time.Sleep(100 * time.Millisecond) + } + + data := []meta.DatabaseInfo{ + { + Name: "db0", + DefaultRetentionPolicy: "rp0", + RetentionPolicies: []meta.RetentionPolicyInfo{ + { + Name: "rp0", + ReplicaN: 1, + Duration: time.Hour, + ShardGroupDuration: time.Hour, + ShardGroups: []meta.ShardGroupInfo{ + { + ID: 1, + StartTime: now.Truncate(time.Hour).Add(-2 * time.Hour), + EndTime: now.Truncate(time.Hour).Add(-1 * time.Hour), + Shards: []meta.ShardInfo{ + {ID: 2}, + {ID: 3}, + }, + }, + { + ID: 4, + StartTime: now.Truncate(time.Hour).Add(-1 * time.Hour), + EndTime: now.Truncate(time.Hour), + Shards: []meta.ShardInfo{ + {ID: 5}, + {ID: 6}, + }, + }, + { + ID: 7, + StartTime: now.Truncate(time.Hour), + EndTime: now.Truncate(time.Hour).Add(time.Hour), + Shards: []meta.ShardInfo{ + {ID: 8}, + {ID: 9}, + }, + }, + }, + }, + }, + }, + } + + config := retention.NewConfig() + config.CheckInterval = toml.Duration(10 * time.Millisecond) + s := NewService(config) + s.MetaClient.DatabasesFn = func() []meta.DatabaseInfo { + return data + } + + done := make(chan struct{}) + deletedShardGroups := make(map[string]struct{}) + s.MetaClient.DeleteShardGroupFn = func(database, policy string, id uint64) error { + for _, dbi := range data { + if dbi.Name == database { + for _, rpi := range dbi.RetentionPolicies { + if rpi.Name == policy { + for i, sg := range rpi.ShardGroups { + if sg.ID == id { + rpi.ShardGroups[i].DeletedAt = time.Now().UTC() + } + } + } + } + } + } + + deletedShardGroups[fmt.Sprintf("%s.%s.%d", database, policy, id)] = struct{}{} + if got, want := deletedShardGroups, map[string]struct{}{ + "db0.rp0.1": struct{}{}, + }; reflect.DeepEqual(got, want) { + close(done) + } else if len(got) > 1 { + t.Errorf("deleted too many shard groups") + } + return nil + } + + pruned := false + closing := make(chan struct{}) + s.MetaClient.PruneShardGroupsFn = func() error { + select { + case <-done: + if !pruned { + close(closing) + pruned = true + } + default: + } + return nil + } + + deletedShards := make(map[uint64]struct{}) + s.TSDBStore.ShardIDsFn = func() []uint64 { + return []uint64{2, 3, 5, 6} + } + s.TSDBStore.DeleteShardFn = func(shardID uint64) error { + deletedShards[shardID] = struct{}{} + return nil + } + + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer func() { + if err := s.Close(); err != nil { + t.Fatalf("unexpected close error: %s", err) + } + }() + + timer := time.NewTimer(100 * time.Millisecond) + select { + case <-done: + timer.Stop() + case <-timer.C: + t.Errorf("timeout waiting for shard groups to be deleted") + return + } + + timer = time.NewTimer(100 * time.Millisecond) + select { + case <-closing: + timer.Stop() + case <-timer.C: + t.Errorf("timeout waiting for shards to be deleted") + return + } + + if got, want := deletedShards, map[uint64]struct{}{ + 2: struct{}{}, + 3: struct{}{}, + }; !reflect.DeepEqual(got, want) { + t.Errorf("unexpected deleted shards: got=%#v want=%#v", got, want) + } +} + // This reproduces https://github.com/influxdata/influxdb/issues/8819 func TestService_8819_repro(t *testing.T) { for i := 0; i < 1000; i++ { diff --git a/services/snapshotter/client_test.go b/services/snapshotter/client_test.go new file mode 100644 index 0000000000..ee5c5f2a73 --- /dev/null +++ b/services/snapshotter/client_test.go @@ -0,0 +1,83 @@ +package snapshotter_test + +import ( + "bytes" + "encoding/binary" + "encoding/json" + 
"net" + "testing" + "time" + + "github.com/influxdata/influxdb" + "github.com/influxdata/influxdb/services/snapshotter" +) + +func TestClient_MetastoreBackup_InvalidMetadata(t *testing.T) { + metaBlob, err := data.MarshalBinary() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + nodeBytes, err := json.Marshal(&influxdb.Node{ID: 1}) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var numBytes [24]byte + + // Write an invalid magic header. + binary.BigEndian.PutUint64(numBytes[:8], snapshotter.BackupMagicHeader+1) + binary.BigEndian.PutUint64(numBytes[8:16], uint64(len(metaBlob))) + binary.BigEndian.PutUint64(numBytes[16:24], uint64(len(nodeBytes))) + + var buf bytes.Buffer + buf.Write(numBytes[:16]) + buf.Write(metaBlob) + buf.Write(numBytes[16:24]) + buf.Write(nodeBytes) + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer l.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := l.Accept() + if err != nil { + t.Errorf("error accepting tcp connection: %s", err) + return + } + defer conn.Close() + + var header [1]byte + if _, err := conn.Read(header[:]); err != nil { + t.Errorf("unable to read mux header: %s", err) + return + } + + var m map[string]interface{} + dec := json.NewDecoder(conn) + if err := dec.Decode(&m); err != nil { + t.Errorf("invalid json request: %s", err) + return + } + conn.Write(buf.Bytes()) + }() + + c := snapshotter.NewClient(l.Addr().String()) + _, err = c.MetastoreBackup() + if err == nil || err.Error() != "invalid metadata received" { + t.Errorf("unexpected error: got=%q want=%q", err, "invalid metadata received") + } + + timer := time.NewTimer(100 * time.Millisecond) + select { + case <-done: + timer.Stop() + case <-timer.C: + t.Errorf("timeout while waiting for the goroutine") + } +} diff --git a/services/snapshotter/service.go b/services/snapshotter/service.go index 2876aa61eb..5abde36217 100644 --- a/services/snapshotter/service.go +++ b/services/snapshotter/service.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "io" "net" "strings" "sync" @@ -29,8 +30,7 @@ const ( // Service manages the listener for the snapshot endpoint. type Service struct { - wg sync.WaitGroup - err chan error + wg sync.WaitGroup Node *influxdb.Node @@ -39,7 +39,11 @@ type Service struct { Database(name string) *meta.DatabaseInfo } - TSDBStore *tsdb.Store + TSDBStore interface { + BackupShard(id uint64, since time.Time, w io.Writer) error + Shard(id uint64) *tsdb.Shard + ShardRelativePath(id uint64) (string, error) + } Listener net.Listener Logger *zap.Logger @@ -48,7 +52,6 @@ type Service struct { // NewService returns a new instance of Service. func NewService() *Service { return &Service{ - err: make(chan error), Logger: zap.NewNop(), } } @@ -76,9 +79,6 @@ func (s *Service) WithLogger(log *zap.Logger) { s.Logger = log.With(zap.String("service", "snapshot")) } -// Err returns a channel for fatal out-of-band errors. -func (s *Service) Err() <-chan error { return s.err } - // serve serves snapshot requests from the listener. 
func (s *Service) serve() { defer s.wg.Done() @@ -198,7 +198,7 @@ func (s *Service) writeDatabaseInfo(conn net.Conn, database string) error { } if err := json.NewEncoder(conn).Encode(res); err != nil { - return fmt.Errorf("encode resonse: %s", err.Error()) + return fmt.Errorf("encode response: %s", err.Error()) } return nil diff --git a/services/snapshotter/service_test.go b/services/snapshotter/service_test.go index 704407e3ca..ab23d71780 100644 --- a/services/snapshotter/service_test.go +++ b/services/snapshotter/service_test.go @@ -1 +1,432 @@ package snapshotter_test + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/influxdata/influxdb/internal" + "github.com/influxdata/influxdb/logger" + "github.com/influxdata/influxdb/services/meta" + "github.com/influxdata/influxdb/services/snapshotter" + "github.com/influxdata/influxdb/tcp" + "github.com/influxdata/influxdb/tsdb" + "github.com/influxdata/influxql" +) + +var data = meta.Data{ + Databases: []meta.DatabaseInfo{ + { + Name: "db0", + DefaultRetentionPolicy: "autogen", + RetentionPolicies: []meta.RetentionPolicyInfo{ + { + Name: "rp0", + ReplicaN: 1, + Duration: 24 * 7 * time.Hour, + ShardGroupDuration: 24 * time.Hour, + ShardGroups: []meta.ShardGroupInfo{ + { + ID: 1, + StartTime: time.Unix(0, 0).UTC(), + EndTime: time.Unix(0, 0).UTC().Add(24 * time.Hour), + Shards: []meta.ShardInfo{ + {ID: 2}, + }, + }, + }, + }, + { + Name: "autogen", + ReplicaN: 1, + ShardGroupDuration: 24 * 7 * time.Hour, + ShardGroups: []meta.ShardGroupInfo{ + { + ID: 3, + StartTime: time.Unix(0, 0).UTC(), + EndTime: time.Unix(0, 0).UTC().Add(24 * time.Hour), + Shards: []meta.ShardInfo{ + {ID: 4}, + }, + }, + }, + }, + }, + }, + }, + Users: []meta.UserInfo{ + { + Name: "admin", + Hash: "abcxyz", + Admin: true, + Privileges: map[string]influxql.Privilege{}, + }, + }, +} + +func init() { + // Set the admin privilege on the user using this method so the meta.Data's check for + // an admin user is set properly. + data.SetAdminPrivilege("admin", true) +} + +func TestSnapshotter_Open(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + + if err := s.Close(); err != nil { + t.Fatalf("unexpected close error: %s", err) + } +} + +func TestSnapshotter_RequestShardBackup(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + var tsdb internal.TSDBStoreMock + tsdb.BackupShardFn = func(id uint64, since time.Time, w io.Writer) error { + if id != 5 { + t.Errorf("unexpected shard id: got=%#v want=%#v", id, 5) + } + if got, want := since, time.Unix(0, 0).UTC(); !got.Equal(want) { + t.Errorf("unexpected time since: got=%#v want=%#v", got, want) + } + // Write some nonsense data so we can check that it gets returned. 
+ w.Write([]byte(`{"status":"ok"}`)) + return nil + } + s.TSDBStore = &tsdb + + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + defer conn.Close() + + req := snapshotter.Request{ + Type: snapshotter.RequestShardBackup, + ShardID: 5, + Since: time.Unix(0, 0), + } + conn.Write([]byte{snapshotter.MuxHeader}) + enc := json.NewEncoder(conn) + if err := enc.Encode(&req); err != nil { + t.Errorf("unable to encode request: %s", err) + return + } + + // Read the result. + out, err := ioutil.ReadAll(conn) + if err != nil { + t.Errorf("unexpected error reading shard backup: %s", err) + return + } + + if got, want := string(out), `{"status":"ok"}`; got != want { + t.Errorf("unexpected shard data: got=%#v want=%#v", got, want) + return + } +} + +func TestSnapshotter_RequestMetastoreBackup(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + s.MetaClient = &MetaClient{Data: data} + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + defer conn.Close() + + c := snapshotter.NewClient(l.Addr().String()) + if got, err := c.MetastoreBackup(); err != nil { + t.Errorf("unable to obtain metastore backup: %s", err) + return + } else if want := &data; !reflect.DeepEqual(got, want) { + t.Errorf("unexpected data backup:\n\ngot=%s\nwant=%s", spew.Sdump(got), spew.Sdump(want)) + return + } +} + +func TestSnapshotter_RequestDatabaseInfo(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + var tsdbStore internal.TSDBStoreMock + tsdbStore.ShardFn = func(id uint64) *tsdb.Shard { + if id != 2 && id != 4 { + t.Errorf("unexpected shard id: %d", id) + return nil + } else if id == 4 { + return nil + } + return &tsdb.Shard{} + } + tsdbStore.ShardRelativePathFn = func(id uint64) (string, error) { + if id == 2 { + return "db0/rp0", nil + } else if id == 4 { + t.Errorf("unexpected relative path request for shard id: %d", id) + } + return "", fmt.Errorf("no such shard id: %d", id) + } + + s.MetaClient = &MetaClient{Data: data} + s.TSDBStore = &tsdbStore + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + defer conn.Close() + + req := snapshotter.Request{ + Type: snapshotter.RequestDatabaseInfo, + Database: "db0", + } + conn.Write([]byte{snapshotter.MuxHeader}) + enc := json.NewEncoder(conn) + if err := enc.Encode(&req); err != nil { + t.Errorf("unable to encode request: %s", err) + return + } + + // Read the result. + out, err := ioutil.ReadAll(conn) + if err != nil { + t.Errorf("unexpected error reading database info: %s", err) + return + } + + // Unmarshal the response. 
+ var resp snapshotter.Response + if err := json.Unmarshal(out, &resp); err != nil { + t.Errorf("error unmarshaling response: %s", err) + return + } + + if got, want := resp.Paths, []string{"db0/rp0"}; !reflect.DeepEqual(got, want) { + t.Errorf("unexpected paths: got=%#v want=%#v", got, want) + } +} + +func TestSnapshotter_RequestDatabaseInfo_ErrDatabaseNotFound(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + s.MetaClient = &MetaClient{Data: data} + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + defer conn.Close() + + req := snapshotter.Request{ + Type: snapshotter.RequestDatabaseInfo, + Database: "doesnotexist", + } + conn.Write([]byte{snapshotter.MuxHeader}) + enc := json.NewEncoder(conn) + if err := enc.Encode(&req); err != nil { + t.Errorf("unable to encode request: %s", err) + return + } + + // Read the result. + out, err := ioutil.ReadAll(conn) + if err != nil { + t.Errorf("unexpected error reading database info: %s", err) + return + } + + // There should be no response. + if got, want := string(out), ""; got != want { + t.Errorf("expected no message, got: %s", got) + } +} + +func TestSnapshotter_RequestRetentionPolicyInfo(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + var tsdbStore internal.TSDBStoreMock + tsdbStore.ShardFn = func(id uint64) *tsdb.Shard { + if id != 2 { + t.Errorf("unexpected shard id: %d", id) + return nil + } + return &tsdb.Shard{} + } + tsdbStore.ShardRelativePathFn = func(id uint64) (string, error) { + if id == 2 { + return "db0/rp0", nil + } + return "", fmt.Errorf("no such shard id: %d", id) + } + + s.MetaClient = &MetaClient{Data: data} + s.TSDBStore = &tsdbStore + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + defer conn.Close() + + req := snapshotter.Request{ + Type: snapshotter.RequestRetentionPolicyInfo, + Database: "db0", + RetentionPolicy: "rp0", + } + conn.Write([]byte{snapshotter.MuxHeader}) + enc := json.NewEncoder(conn) + if err := enc.Encode(&req); err != nil { + t.Errorf("unable to encode request: %s", err) + return + } + + // Read the result. + out, err := ioutil.ReadAll(conn) + if err != nil { + t.Errorf("unexpected error reading database info: %s", err) + return + } + + // Unmarshal the response. + var resp snapshotter.Response + if err := json.Unmarshal(out, &resp); err != nil { + t.Errorf("error unmarshaling response: %s", err) + return + } + + if got, want := resp.Paths, []string{"db0/rp0"}; !reflect.DeepEqual(got, want) { + t.Errorf("unexpected paths: got=%#v want=%#v", got, want) + } +} + +func TestSnapshotter_InvalidRequest(t *testing.T) { + s, l, err := NewTestService() + if err != nil { + t.Fatal(err) + } + defer l.Close() + + if err := s.Open(); err != nil { + t.Fatalf("unexpected open error: %s", err) + } + defer s.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + defer conn.Close() + + conn.Write([]byte{snapshotter.MuxHeader}) + conn.Write([]byte(`["invalid request"]`)) + + // Read the result. 
+ out, err := ioutil.ReadAll(conn) + if err != nil { + t.Errorf("unexpected error reading database info: %s", err) + return + } + + // There should be no response. + if got, want := string(out), ""; got != want { + t.Errorf("expected no message, got: %s", got) + } +} + +func NewTestService() (*snapshotter.Service, net.Listener, error) { + s := snapshotter.NewService() + s.WithLogger(logger.New(os.Stderr)) + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + + // The snapshotter needs to be used with a tcp.Mux listener. + mux := tcp.NewMux() + go mux.Serve(l) + + s.Listener = mux.Listen(snapshotter.MuxHeader) + return s, l, nil +} + +type MetaClient struct { + Data meta.Data +} + +func (m *MetaClient) MarshalBinary() ([]byte, error) { + return m.Data.MarshalBinary() +} + +func (m *MetaClient) Database(name string) *meta.DatabaseInfo { + for _, dbi := range m.Data.Databases { + if dbi.Name == name { + return &dbi + } + } + return nil +} diff --git a/tcp/mux.go b/tcp/mux.go index e7b901a8e0..e45d69cf28 100644 --- a/tcp/mux.go +++ b/tcp/mux.go @@ -81,9 +81,13 @@ func (mux *Mux) Serve(ln net.Listener) error { if err != nil { // Wait for all connections to be demux mux.wg.Wait() + + mux.mu.Lock() for _, ln := range mux.m { close(ln.c) } + mux.m = nil + mux.mu.Unlock() if mux.defaultListener != nil { close(mux.defaultListener.c) @@ -169,6 +173,20 @@ func (mux *Mux) Listen(header byte) net.Listener { return ln } +// release removes the listener from the mux. +func (mux *Mux) release(ln *listener) bool { + mux.mu.Lock() + defer mux.mu.Unlock() + + for b, l := range mux.m { + if l == ln { + delete(mux.m, b) + return true + } + } + return false +} + // DefaultListener will return a net.Listener that will pass-through any // connections with non-registered values for the first byte of the connection. // The connections returned from this listener's Accept() method will replay the @@ -203,8 +221,13 @@ func (ln *listener) Accept() (c net.Conn, err error) { return conn, nil } -// Close is a no-op. The mux's listener should be closed instead. -func (ln *listener) Close() error { return nil } +// Close removes this listener from the parent mux and closes the channel. +func (ln *listener) Close() error { + if ok := ln.mux.release(ln); ok { + close(ln.c) + } + return nil +} // Addr returns the Addr of the listener func (ln *listener) Addr() net.Addr { diff --git a/tcp/mux_test.go b/tcp/mux_test.go index 78c041eb57..0b2930b69a 100644 --- a/tcp/mux_test.go +++ b/tcp/mux_test.go @@ -154,3 +154,59 @@ func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) { mux.Listen(5) mux.Listen(5) } + +// Ensure that closing a listener from mux.Listen releases an Accept call and +// deregisters the mux. +func TestMux_Close(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + done := make(chan struct{}) + mux := tcp.NewMux() + go func() { + mux.Serve(listener) + close(done) + }() + l := mux.Listen(5) + + closed := make(chan struct{}) + go func() { + _, err := l.Accept() + if err == nil || !strings.Contains(err.Error(), "connection closed") { + t.Errorf("unexpected error: %s", err) + } + close(closed) + }() + l.Close() + + timer := time.NewTimer(100 * time.Millisecond) + select { + case <-closed: + timer.Stop() + case <-timer.C: + t.Errorf("timeout while waiting for the mux to close") + } + + // We should now be able to register a new listener at the same byte + // without causing a panic. 
+ defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected recover: %#v", r) + } + }() + l = mux.Listen(5) + + // Verify that closing the listener does not cause a panic. + listener.Close() + timer = time.NewTimer(100 * time.Millisecond) + select { + case <-done: + timer.Stop() + // This should not panic. + l.Close() + case <-timer.C: + t.Errorf("timeout while waiting for the mux to close") + } +}
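Note on the final hunk: the tcp.Mux listener's Close is no longer a no-op; it now deregisters the listener's header byte from the mux and closes its channel, which is what TestMux_Close exercises. A minimal usage sketch of that behavior, using only the calls that appear in the test above (tcp.NewMux, (*Mux).Serve, (*Mux).Listen, and the listener's Close); it is illustrative only, not code from this change:

package main

import (
	"log"
	"net"

	"github.com/influxdata/influxdb/tcp"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	mux := tcp.NewMux()
	go mux.Serve(ln)

	// Register a listener for header byte 5, then close it. With the change
	// above, Close releases the byte from the mux and unblocks any pending
	// Accept with a "connection closed" error.
	l := mux.Listen(5)
	if err := l.Close(); err != nil {
		log.Fatal(err)
	}

	// Re-registering the same header byte no longer panics, because the
	// previous registration was removed by Close.
	l = mux.Listen(5)
	_ = l
}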