fix: lock MeasurementFields while validating (#25998)

There was a window where a race between writes with
differing types for the same field were being validated.
Lock the  MeasurementFields struct during field
validation to avoid this.

closes https://github.com/influxdata/influxdb/issues/23756
pull/26022/head
davidby-influx 2025-02-13 11:33:34 -08:00 committed by GitHub
parent 4ad5e2aba7
commit 5a20a835a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 247 additions and 95 deletions

View File

@ -1290,7 +1290,7 @@ func (e *Engine) addToIndexFromKey(keys [][]byte, fieldTypes []influxql.DataType
keys[i], field = SeriesAndFieldFromCompositeKey(keys[i])
name := models.ParseName(keys[i])
mf := e.fieldset.CreateFieldsIfNotExists(name)
if err := mf.CreateFieldIfNotExists(field, fieldTypes[i]); err != nil {
if _, err := mf.CreateFieldIfNotExists(field, fieldTypes[i]); err != nil {
return err
}

View File

@ -13,9 +13,11 @@ const MaxFieldValueLength = 1048576
// ValidateFields will return a PartialWriteError if:
// - the point has inconsistent fields, or
// - the point has fields that are too long
func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidation bool) error {
func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidation bool) ([]*FieldCreate, error) {
pointSize := point.StringSize()
iter := point.FieldIterator()
var fieldsToCreate []*FieldCreate
for iter.Next() {
if !skipSizeValidation {
// Check for size of field too large. Note it is much cheaper to check the whole point size
@ -23,7 +25,7 @@ func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidatio
// unescape the string, and must at least parse the string)
if pointSize > MaxFieldValueLength && iter.Type() == models.String {
if sz := len(iter.StringValue()); sz > MaxFieldValueLength {
return PartialWriteError{
return nil, PartialWriteError{
Reason: fmt.Sprintf(
"input field \"%s\" on measurement \"%s\" is too long, %d > %d",
iter.FieldKey(), point.Name(), sz, MaxFieldValueLength),
@ -33,14 +35,9 @@ func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidatio
}
}
fieldKey := iter.FieldKey()
// Skip fields name "time", they are illegal.
if bytes.Equal(iter.FieldKey(), timeBytes) {
continue
}
// If the fields is not present, there cannot be a conflict.
f := mf.FieldBytes(iter.FieldKey())
if f == nil {
if bytes.Equal(fieldKey, timeBytes) {
continue
}
@ -49,18 +46,26 @@ func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidatio
continue
}
// If the types are not the same, there is a conflict.
if f.Type != dataType {
return PartialWriteError{
// If the field is not present, remember to create it.
f := mf.FieldBytes(fieldKey)
if f == nil {
fieldsToCreate = append(fieldsToCreate, &FieldCreate{
Measurement: point.Name(),
Field: &Field{
Name: string(fieldKey),
Type: dataType,
}})
} else if f.Type != dataType {
// If the types are not the same, there is a conflict.
return nil, PartialWriteError{
Reason: fmt.Sprintf(
"%s: input field \"%s\" on measurement \"%s\" is type %s, already exists as type %s",
ErrFieldTypeConflict, iter.FieldKey(), point.Name(), dataType, f.Type),
ErrFieldTypeConflict, fieldKey, point.Name(), dataType, f.Type),
Dropped: 1,
}
}
}
return nil
return fieldsToCreate, nil
}
// dataTypeFromModelsFieldType returns the influxql.DataType that corresponds to the

View File

@ -572,13 +572,13 @@ func (s *Shard) WritePoints(points []models.Point, tracker StatsTracker) error {
// to the caller, but continue on writing the remaining points.
writeError = err
}
atomic.AddInt64(&s.stats.FieldsCreated, int64(len(fieldsToCreate)))
// add any new fields and keep track of what needs to be saved
if err := s.createFieldsAndMeasurements(fieldsToCreate); err != nil {
if numFieldsCreated, err := s.createFieldsAndMeasurements(fieldsToCreate); err != nil {
return err
} else {
atomic.AddInt64(&s.stats.FieldsCreated, int64(numFieldsCreated))
}
engineTracker := tracker
engineTracker.AddedPoints = func(points, values int64) {
if tracker.AddedPoints != nil {
@ -697,61 +697,44 @@ func (s *Shard) validateSeriesAndFields(points []models.Point, tracker StatsTrac
continue
}
// Skip any points whos keys have been dropped. Dropped has already been incremented for them.
// Skip any points whose keys have been dropped. Dropped has already been incremented for them.
if len(droppedKeys) > 0 && bytesutil.Contains(droppedKeys, keys[i]) {
continue
}
name := p.Name()
mf := engine.MeasurementFields(name)
// Check with the field validator.
if err := ValidateFields(mf, p, s.options.Config.SkipFieldSizeValidation); err != nil {
switch err := err.(type) {
case PartialWriteError:
if reason == "" {
reason = err.Reason
err := func(p models.Point, iter models.FieldIterator) error {
var newFields []*FieldCreate
var validateErr error
name := p.Name()
mf := engine.MeasurementFields(name)
mf.mu.RLock()
defer mf.mu.RUnlock()
// Check with the field validator.
if newFields, validateErr = ValidateFields(mf, p, s.options.Config.SkipFieldSizeValidation); validateErr != nil {
var err PartialWriteError
switch {
case errors.As(validateErr, &err):
// This will turn into an error later, outside this lambda
if reason == "" {
reason = err.Reason
}
dropped += err.Dropped
atomic.AddInt64(&s.stats.WritePointsDropped, int64(err.Dropped))
default:
return err
}
dropped += err.Dropped
atomic.AddInt64(&s.stats.WritePointsDropped, int64(err.Dropped))
default:
return nil, nil, err
}
continue
}
points[j] = points[i]
j++
// Create any fields that are missing.
iter.Reset()
for iter.Next() {
fieldKey := iter.FieldKey()
// Skip fields named "time". They are illegal.
if bytes.Equal(fieldKey, timeBytes) {
continue
return nil
}
if mf.FieldBytes(fieldKey) != nil {
continue
}
dataType := dataTypeFromModelsFieldType(iter.Type())
if dataType == influxql.Unknown {
continue
}
fieldsToCreate = append(fieldsToCreate, &FieldCreate{
Measurement: name,
Field: &Field{
Name: string(fieldKey),
Type: dataType,
},
})
points[j] = points[i]
j++
fieldsToCreate = append(fieldsToCreate, newFields...)
return nil
}(p, iter)
if err != nil {
return nil, nil, err
}
}
if dropped > 0 {
err = PartialWriteError{Reason: reason, Dropped: dropped, Database: s.database, RetentionPolicy: s.retentionPolicy}
}
@ -781,31 +764,33 @@ func makePrintable(s string) string {
return b.String()
}
func (s *Shard) createFieldsAndMeasurements(fieldsToCreate []*FieldCreate) error {
func (s *Shard) createFieldsAndMeasurements(fieldsToCreate []*FieldCreate) (int, error) {
if len(fieldsToCreate) == 0 {
return nil
return 0, nil
}
engine, err := s.engineNoLock()
if err != nil {
return err
return 0, err
}
numCreated := 0
// add fields
changes := make([]*FieldChange, 0, len(fieldsToCreate))
for _, f := range fieldsToCreate {
mf := engine.MeasurementFields(f.Measurement)
if err := mf.CreateFieldIfNotExists([]byte(f.Field.Name), f.Field.Type); err != nil {
return err
if created, err := mf.CreateFieldIfNotExists([]byte(f.Field.Name), f.Field.Type); err != nil {
return 0, err
} else if created {
numCreated++
s.index.SetFieldName(f.Measurement, f.Field.Name)
changes = append(changes, &FieldChange{
FieldCreate: *f,
ChangeType: AddMeasurementField,
})
}
s.index.SetFieldName(f.Measurement, f.Field.Name)
changes = append(changes, &FieldChange{
FieldCreate: *f,
ChangeType: AddMeasurementField,
})
}
return engine.MeasurementFieldSet().Save(changes)
return numCreated, engine.MeasurementFieldSet().Save(changes)
}
// DeleteSeriesRange deletes all values from for seriesKeys between min and max (inclusive)
@ -1577,7 +1562,7 @@ func (a Shards) ExpandSources(sources influxql.Sources) (influxql.Sources, error
// MeasurementFields holds the fields of a measurement and their codec.
type MeasurementFields struct {
mu sync.Mutex
mu sync.RWMutex
fields atomic.Value // map[string]*Field
}
@ -1616,15 +1601,15 @@ func (m *MeasurementFields) bytes() int {
// CreateFieldIfNotExists creates a new field with an autoincrementing ID.
// Returns an error if 255 fields have already been created on the measurement or
// the fields already exists with a different type.
func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.DataType) error {
func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.DataType) (bool, error) {
fields := m.fields.Load().(map[string]*Field)
// Ignore if the field already exists.
if f := fields[string(name)]; f != nil {
if f.Type != typ {
return ErrFieldTypeConflict
return false, ErrFieldTypeConflict
}
return nil
return false, nil
}
m.mu.Lock()
@ -1634,9 +1619,9 @@ func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.Dat
// Re-check field and type under write lock.
if f := fields[string(name)]; f != nil {
if f.Type != typ {
return ErrFieldTypeConflict
return false, ErrFieldTypeConflict
}
return nil
return false, nil
}
fieldsUpdate := make(map[string]*Field, len(fields)+1)
@ -1652,7 +1637,7 @@ func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.Dat
fieldsUpdate[string(name)] = f
m.fields.Store(fieldsUpdate)
return nil
return true, nil
}
func (m *MeasurementFields) FieldN() int {
@ -2325,7 +2310,7 @@ func (fs *MeasurementFieldSet) ApplyChanges() error {
fs.Delete(string(fc.Measurement))
} else {
mf := fs.CreateFieldsIfNotExists(fc.Measurement)
if err := mf.CreateFieldIfNotExists([]byte(fc.Field.Name), fc.Field.Type); err != nil {
if _, err := mf.CreateFieldIfNotExists([]byte(fc.Field.Name), fc.Field.Type); err != nil {
err = fmt.Errorf("failed creating %q.%q: %w", fc.Measurement, fc.Field.Name, err)
log.Error("field creation", zap.Error(err))
return err

View File

@ -7,6 +7,7 @@ import (
"fmt"
"math"
"os"
"path"
"path/filepath"
"reflect"
"regexp"
@ -14,12 +15,10 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
assert2 "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/davecgh/go-spew/spew"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
@ -30,9 +29,12 @@ import (
"github.com/influxdata/influxdb/query"
"github.com/influxdata/influxdb/tsdb"
_ "github.com/influxdata/influxdb/tsdb/engine"
"github.com/influxdata/influxdb/tsdb/engine/tsm1"
_ "github.com/influxdata/influxdb/tsdb/index"
"github.com/influxdata/influxdb/tsdb/index/inmem"
"github.com/influxdata/influxql"
assert2 "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShardWriteAndIndex(t *testing.T) {
@ -52,7 +54,8 @@ func TestShardWriteAndIndex(t *testing.T) {
// Calling WritePoints when the engine is not open will return
// ErrEngineClosed.
if got, exp := sh.WritePoints(nil, tsdb.NoopStatsTracker()), tsdb.ErrEngineClosed; got != exp {
got := sh.WritePoints(nil, tsdb.NoopStatsTracker())
if exp := tsdb.ErrEngineClosed; got != exp {
t.Fatalf("got %v, expected %v", got, exp)
}
@ -122,7 +125,8 @@ func TestShard_Open_CorruptFieldsIndex(t *testing.T) {
// Calling WritePoints when the engine is not open will return
// ErrEngineClosed.
if got, exp := sh.WritePoints(nil, tsdb.NoopStatsTracker()), tsdb.ErrEngineClosed; got != exp {
got := sh.WritePoints(nil, tsdb.NoopStatsTracker())
if exp := tsdb.ErrEngineClosed; got != exp {
t.Fatalf("got %v, expected %v", got, exp)
}
@ -1687,7 +1691,7 @@ func TestMeasurementFieldSet_SaveLoad(t *testing.T) {
}
defer checkMeasurementFieldSetClose(t, mf)
fields := mf.CreateFieldsIfNotExists([]byte(measurement))
if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
t.Fatalf("create field error: %v", err)
}
change := tsdb.FieldChange{
@ -1739,7 +1743,7 @@ func TestMeasurementFieldSet_Corrupt(t *testing.T) {
measurement := []byte("cpu")
fields := mf.CreateFieldsIfNotExists(measurement)
fieldName := "value"
if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
t.Fatalf("create field error: %v", err)
}
change := tsdb.FieldChange{
@ -1810,7 +1814,7 @@ func TestMeasurementFieldSet_CorruptChangeFile(t *testing.T) {
defer checkMeasurementFieldSetClose(t, mf)
for _, f := range testFields {
fields := mf.CreateFieldsIfNotExists([]byte(f.Measurement))
if err := fields.CreateFieldIfNotExists([]byte(f.Field), f.FieldType); err != nil {
if _, err := fields.CreateFieldIfNotExists([]byte(f.Field), f.FieldType); err != nil {
t.Fatalf("create field error: %v", err)
}
change := tsdb.FieldChange{
@ -1872,7 +1876,7 @@ func TestMeasurementFieldSet_DeleteEmpty(t *testing.T) {
defer checkMeasurementFieldSetClose(t, mf)
fields := mf.CreateFieldsIfNotExists([]byte(measurement))
if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
t.Fatalf("create field error: %v", err)
}
@ -2005,7 +2009,7 @@ func testFieldMaker(t *testing.T, wg *sync.WaitGroup, mf *tsdb.MeasurementFieldS
fields := mf.CreateFieldsIfNotExists([]byte(measurement))
for _, fieldName := range fieldNames {
if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil {
t.Logf("create field error: %v", err)
t.Fail()
return
@ -2655,3 +2659,161 @@ func (a seriesIDSets) ForEach(f func(ids *tsdb.SeriesIDSet)) error {
}
return nil
}
// Tests concurrently writing to the same shard with different field types which
// can trigger a panic when the shard is snapshotted to TSM files.
func TestShard_WritePoints_ForceFieldConflictConcurrent(t *testing.T) {
const Runs = 50
if testing.Short() || runtime.GOOS == "windows" {
t.Skip("Skipping on short or windows")
}
for i := 0; i < Runs; i++ {
conflictShard(t, i)
}
}
func conflictShard(t *testing.T, run int) {
const measurement = "cpu"
const field = "value"
const numTypes = 4 // float, int, bool, string
const pointCopies = 10
const trialsPerShard = 10
tmpDir, _ := os.MkdirTemp("", "shard_test")
defer func() {
require.NoError(t, os.RemoveAll(tmpDir), "removing %s", tmpDir)
}()
tmpShard := filepath.Join(tmpDir, "shard")
tmpWal := filepath.Join(tmpDir, "wal")
sfile := MustOpenSeriesFile()
defer func() {
require.NoError(t, sfile.Close(), "closing series file")
require.NoError(t, os.RemoveAll(sfile.Path()), "removing series file %s", sfile.Path())
}()
opts := tsdb.NewEngineOptions()
opts.Config.WALDir = tmpWal
opts.InmemIndex = inmem.NewIndex(filepath.Base(tmpDir), sfile.SeriesFile)
opts.SeriesIDSets = seriesIDSets([]*tsdb.SeriesIDSet{})
sh := tsdb.NewShard(1, tmpShard, tmpWal, sfile.SeriesFile, opts)
require.NoError(t, sh.Open(), "opening shard: %s", sh.Path())
defer func() {
require.NoError(t, sh.Close(), "closing shard %s", tmpShard)
}()
var wg sync.WaitGroup
mu := sync.RWMutex{}
maxConcurrency := atomic.Int64{}
currentTime := time.Now()
points := make([]models.Point, 0, pointCopies*numTypes)
for i := 0; i < pointCopies; i++ {
points = append(points, models.MustNewPoint(
measurement,
models.NewTags(map[string]string{"host": "server"}),
map[string]interface{}{field: 1.0},
currentTime.Add(time.Duration(i)*time.Second),
))
points = append(points, models.MustNewPoint(
measurement,
models.NewTags(map[string]string{"host": "server"}),
map[string]interface{}{field: int64(1)},
currentTime.Add(time.Duration(i)*time.Second),
))
points = append(points, models.MustNewPoint(
measurement,
models.NewTags(map[string]string{"host": "server"}),
map[string]interface{}{field: "one"},
currentTime.Add(time.Duration(i)*time.Second),
))
points = append(points, models.MustNewPoint(
measurement,
models.NewTags(map[string]string{"host": "server"}),
map[string]interface{}{field: true},
currentTime.Add(time.Duration(i)*time.Second),
))
}
concurrency := atomic.Int64{}
for i := 0; i < trialsPerShard; i++ {
mu.Lock()
wg.Add(len(points))
// Write points concurrently
for i := 0; i < pointCopies; i++ {
for j := 0; j < numTypes; j++ {
concurrency.Add(1)
go func(mp models.Point) {
mu.RLock()
defer concurrency.Add(-1)
defer mu.RUnlock()
defer wg.Done()
if err := sh.WritePoints([]models.Point{mp}, tsdb.NoopStatsTracker()); err == nil {
fs, err := mp.Fields()
require.NoError(t, err, "getting fields")
require.Equal(t,
sh.MeasurementFields([]byte(measurement)).Field(field).Type,
influxql.InspectDataType(fs[field]),
"field types mismatch on run %d: types exp: %s, got: %s", run+1, sh.MeasurementFields([]byte(measurement)).Field(field).Type.String(), influxql.InspectDataType(fs[field]).String())
} else {
require.ErrorContains(t, err, tsdb.ErrFieldTypeConflict.Error(), "unexpected error")
}
if c := concurrency.Load(); maxConcurrency.Load() < c {
maxConcurrency.Store(c)
}
}(points[i*numTypes+j])
}
}
mu.Unlock()
wg.Wait()
dir, err := sh.CreateSnapshot(false)
require.NoError(t, err, "creating snapshot: %s", sh.Path())
require.NoError(t, os.RemoveAll(dir), "removing snapshot directory %s", dir)
}
keyType := map[string]byte{}
files, err := os.ReadDir(tmpShard)
require.NoError(t, err, "reading shard directory %s", tmpShard)
for i, file := range files {
if !strings.HasSuffix(path.Ext(file.Name()), tsm1.TSMFileExtension) {
continue
}
ffile := path.Join(tmpShard, file.Name())
fh, err := os.Open(ffile)
require.NoError(t, err, "opening snapshot file %s", ffile)
tr, err := tsm1.NewTSMReader(fh)
require.NoError(t, err, "creating TSM reader for %s", ffile)
key, typ := tr.KeyAt(0)
if oldTyp, ok := keyType[string(key)]; ok {
require.Equal(t, oldTyp, typ,
"field type mismatch in run %d TSM file %d -- %q in %s\nfirst seen: %s, newest: %s, field type: %s",
run+1,
i+1,
string(key),
ffile,
blockTypeString(oldTyp),
blockTypeString(typ),
sh.MeasurementFields([]byte(measurement)).Field(field).Type.String())
} else {
keyType[string(key)] = typ
}
// Must close after all uses of key (mapped memory)
require.NoError(t, tr.Close(), "closing TSM reader")
}
// t.Logf("Type %s wins run %d with concurrency: %d", sh.MeasurementFields([]byte(measurement)).Field(field).Type.String(), run+1, maxConcurrency.Load())
}
func blockTypeString(typ byte) string {
switch typ {
case tsm1.BlockFloat64:
return "float64"
case tsm1.BlockInteger:
return "int64"
case tsm1.BlockBoolean:
return "bool"
case tsm1.BlockString:
return "string"
default:
return "unknown"
}
}