package proxy

import (
	"context"
	"errors"
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"sync"

	"github.com/milvus-io/milvus/internal/parser/planparserv2"

	"github.com/golang/protobuf/proto"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/metrics"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/timerecord"
	"github.com/milvus-io/milvus/internal/util/tsoutil"
	"github.com/milvus-io/milvus/internal/util/typeutil"

	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
)

// Values for the withCache argument of globalMetaCache.GetShards. Execute
// first asks with WithCache and, after errInvalidShardLeaders, retries once
// with WithoutCache.
const (
	// WithCache allows the shard-leader lookup to be served from the cache.
	WithCache = true
	// WithoutCache is the negation of WithCache; used for the retry path.
	WithoutCache = !WithCache
)

// queryTask is the proxy-side task for a Query request. PreExecute turns the
// client's milvuspb.QueryRequest into an internal RetrieveRequest, Execute
// sends it to a leader of every shard in parallel, and PostExecute merges the
// per-shard results into result.
type queryTask struct {
	// Condition is embedded completion state defined elsewhere in the package
	// (presumably used by the scheduler to wait for the task — TODO confirm).
	Condition
	// RetrieveRequest is the internal request filled in by PreExecute and sent
	// to query nodes as querypb.QueryRequest.Req by queryShard.
	*internalpb.RetrieveRequest

	// ctx is returned by TraceCtx and used for the calls in checkIfLoaded.
	ctx            context.Context
	// result holds the merged rows once PostExecute has run.
	result         *milvuspb.QueryResults
	// request is the original client request (collection/partition names,
	// expression, output fields, timestamps).
	request        *milvuspb.QueryRequest
	// qc is the QueryCoord client used for shard lookup and ShowPartitions.
	qc             types.QueryCoord
	// ids, when non-nil, replaces request.Expr with a primary-key "in"
	// expression built by IDs2Expr.
	ids            *schemapb.IDs
	// collectionName is copied from request.CollectionName in PreExecute.
	collectionName string

	// resultBuf receives one result per successful shard query; it is created
	// in Execute (buffered with the number of shards) and drained in PostExecute.
	resultBuf       chan *internalpb.RetrieveResults
	// toReduceResults collects the drained results before merging.
	toReduceResults []*internalpb.RetrieveResults
	// runningGroup/runningGroupCtx run the per-shard queries in Execute; the
	// context is derived from Execute's ctx via errgroup.WithContext.
	runningGroup    *errgroup.Group
	runningGroupCtx context.Context

	// getQueryNodePolicy obtains a QueryNode client for a leader; defaults to
	// defaultGetQueryNodePolicy in PreExecute.
	getQueryNodePolicy getQueryNodePolicy
	// queryShardPolicy chooses which shard leader(s) to query; defaults to
	// roundRobinPolicy in PreExecute.
	queryShardPolicy   pickShardPolicy
}

// PreExecute validates the user request and fills in the internal
// RetrieveRequest that Execute sends to the query nodes.
//
// It resolves the collection ID and the requested partition IDs, verifies that
// the collection (or the requested partitions) is loaded, builds the retrieve
// plan from t.request.Expr (or from t.ids when set), resolves the output field
// IDs (when output fields are given, the primary key is included exactly once),
// and sets the travel, guarantee and timeout timestamps.
//
// It returns an error if a name is invalid or unknown, the target is not
// loaded, the schema cannot be fetched, the expression is empty or invalid, or
// a requested output field does not exist.
func (t *queryTask) PreExecute(ctx context.Context) error {
	if t.getQueryNodePolicy == nil {
		t.getQueryNodePolicy = defaultGetQueryNodePolicy
	}

	if t.queryShardPolicy == nil {
		t.queryShardPolicy = roundRobinPolicy
	}

	t.Base.MsgType = commonpb.MsgType_Retrieve
	t.Base.SourceID = Params.ProxyCfg.GetNodeID()

	collectionName := t.request.CollectionName
	t.collectionName = collectionName
	if err := validateCollectionName(collectionName); err != nil {
		log.Warn("Invalid collection name.", zap.String("collectionName", collectionName),
			zap.Int64("requestID", t.Base.MsgID), zap.String("requestType", "query"))
		return err
	}

	log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))

	collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
	if err != nil {
		log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
			zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
		return err
	}

	t.CollectionID = collID
	log.Info("Get collection ID by name",
		zap.Int64("collectionID", t.CollectionID), zap.String("collection name", collectionName),
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))

	for _, tag := range t.request.PartitionNames {
		if err := validatePartitionTag(tag, false); err != nil {
			log.Warn("invalid partition name", zap.String("partition name", tag),
				zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
			return err
		}
	}
	log.Debug("Validate partition names.",
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))

	t.PartitionIDs = make([]UniqueID, 0)
	partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
	if err != nil {
		log.Warn("failed to get partitions in collection.", zap.String("collection name", collectionName),
			zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
		return err
	}
	log.Debug("Get partitions in collection.", zap.Any("collectionName", collectionName),
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))

	// Check if partitions are valid partitions in collection. Each requested
	// name is used as an anchored regular expression against the collection's
	// partition names; every matching partition ID is collected once.
	partitionsRecord := make(map[UniqueID]bool)
	for _, partitionName := range t.request.PartitionNames {
		pattern := fmt.Sprintf("^%s$", partitionName)
		re, err := regexp.Compile(pattern)
		if err != nil {
			log.Debug("failed to compile partition name regex expression.", zap.Any("partition name", partitionName),
				zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
			return errors.New("invalid partition names")
		}
		found := false
		for name, pID := range partitionsMap {
			if re.MatchString(name) {
				if _, exist := partitionsRecord[pID]; !exist {
					t.PartitionIDs = append(t.PartitionIDs, pID)
					partitionsRecord[pID] = true
				}
				found = true
			}
		}
		if !found {
			// FIXME(wxyu): undefined behavior
			errMsg := fmt.Sprintf("partition name: %s not found", partitionName)
			return errors.New(errMsg)
		}
	}

	loaded, err := t.checkIfLoaded(collID, t.PartitionIDs)
	if err != nil {
		return fmt.Errorf("checkIfLoaded failed when query, collection:%v, partitions:%v, err = %w", collectionName, t.request.GetPartitionNames(), err)
	}
	if !loaded {
		return fmt.Errorf("collection:%v or partition:%v not loaded into memory when query", collectionName, t.request.GetPartitionNames())
	}

	// Without a schema neither the primary key nor the output fields can be
	// resolved, so a lookup failure must abort instead of dereferencing nil.
	schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
	if err != nil {
		log.Warn("failed to get collection schema.", zap.String("collection name", collectionName),
			zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
		return err
	}

	// A query by IDs is rewritten into a primary-key "in" expression.
	if t.ids != nil {
		pkField := ""
		for _, field := range schema.Fields {
			if field.IsPrimaryKey {
				pkField = field.Name
			}
		}
		t.request.Expr = IDs2Expr(pkField, t.ids)
	}

	if t.request.Expr == "" {
		return fmt.Errorf("query expression is empty")
	}

	plan, err := planparserv2.CreateRetrievePlan(schema, t.request.Expr)
	if err != nil {
		return err
	}
	t.request.OutputFields, err = translateOutputFields(t.request.OutputFields, schema, true)
	if err != nil {
		return err
	}
	log.Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields),
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))

	if len(t.request.OutputFields) == 0 {
		// No explicit output fields: return every user field (ID >= 100)
		// except vector fields.
		for _, field := range schema.Fields {
			if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
				t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID)
			}
		}
	} else {
		// The primary key is always returned. addPrimaryKey records whether its
		// field ID has already been appended, so it appears exactly once whether
		// it was requested explicitly or added implicitly for an earlier field.
		addPrimaryKey := false
		for _, reqField := range t.request.OutputFields {
			findField := false
			for _, field := range schema.Fields {
				if reqField == field.Name {
					findField = true
					if field.IsPrimaryKey {
						if addPrimaryKey {
							// Already appended while handling an earlier field.
							continue
						}
						addPrimaryKey = true
					}
					t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID)
					plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
				} else if field.IsPrimaryKey && !addPrimaryKey {
					t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID)
					plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
					addPrimaryKey = true
				}
			}
			if !findField {
				return fmt.Errorf("field %s not exist", reqField)
			}
		}
	}
	log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId),
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))

	t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan)
	if err != nil {
		return err
	}

	// A zero travel timestamp means "query as of the task's begin timestamp".
	if t.request.TravelTimestamp == 0 {
		t.TravelTimestamp = t.BeginTs()
	} else {
		t.TravelTimestamp = t.request.TravelTimestamp
	}

	err = validateTravelTimestamp(t.TravelTimestamp, t.BeginTs())
	if err != nil {
		return err
	}

	guaranteeTs := t.request.GetGuaranteeTimestamp()
	t.GuaranteeTimestamp = parseGuaranteeTs(guaranteeTs, t.BeginTs())

	// Propagate the caller's deadline to the query nodes as a timestamp.
	deadline, ok := t.TraceCtx().Deadline()
	if ok {
		t.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
	}

	t.DbID = 0 // TODO
	log.Info("Query PreExecute done.",
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
	return nil
}

// Execute sends the serialized retrieve request to a leader of every shard of
// the collection in parallel. Each successful shard query pushes its result
// into t.resultBuf, which PostExecute drains afterwards. The first failing
// shard cancels t.runningGroupCtx for the others (errgroup semantics).
//
// If the shard leaders turn out to be invalid (errInvalidShardLeaders, possibly
// wrapped), the shard-leader lookup is repeated once without the cache and all
// shards are queried again. Any remaining error is returned wrapped.
func (t *queryTask) Execute(ctx context.Context) error {
	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", t.ID()))
	defer tr.Elapse("done")

	executeQuery := func(withCache bool) error {
		shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
		if err != nil {
			return err
		}

		// One buffer slot per shard so that senders never block while results
		// wait to be drained in PostExecute (assumes queryShardPolicy delivers at
		// most one successful result per shard — TODO confirm).
		t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
		t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
		t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
		for channelID, leaders := range shards {
			channelID := channelID
			leaders := leaders
			t.runningGroup.Go(func() error {
				log.Debug("proxy starting to query one shard",
					zap.Int64("collectionID", t.CollectionID),
					zap.String("collection name", t.collectionName),
					zap.String("shard channel", channelID),
					zap.Uint64("timeoutTs", t.TimeoutTimestamp))

				return t.queryShard(t.runningGroupCtx, leaders, channelID)
			})
		}

		return t.runningGroup.Wait()
	}

	err := executeQuery(WithCache)
	if errors.Is(err, errInvalidShardLeaders) {
		log.Warn("invalid shard leaders cache, updating shardleader caches and retry search")
		err = executeQuery(WithoutCache)
	}
	if err != nil {
		return fmt.Errorf("fail to search on all shard leaders, err=%w", err)
	}

	log.Info("Query Execute done.",
		zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
	return nil
}

// PostExecute collects the per-shard results buffered by Execute, merges them
// (de-duplicating rows by primary key) into t.result, and fills in the name,
// ID and data type of every returned column from the collection schema.
//
// If the task context is done before the results are collected, an error is
// returned instead of a partial result. An empty merge result is reported via
// t.result.Status (ErrorCode_EmptyCollection) with a nil error.
func (t *queryTask) PostExecute(ctx context.Context) error {
	tr := timerecord.NewTimeRecorder("queryTask PostExecute")
	defer func() {
		tr.Elapse("done")
	}()

	var (
		err      error
		timedOut bool
		wg       sync.WaitGroup
	)
	wg.Add(1)
	go func() {
		// Done must run on every branch, otherwise wg.Wait() below blocks forever.
		defer wg.Done()
		select {
		case <-t.TraceCtx().Done():
			log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, taskID:", t.ID()))
			// Read only after wg.Wait(), which orders it after this write.
			timedOut = true
		case <-t.runningGroupCtx.Done():
			log.Debug("all queries are finished or canceled", zap.Any("taskID", t.ID()))
			// All senders have returned by now (Execute waited on the group),
			// so closing here is safe and lets the range below terminate.
			close(t.resultBuf)
			for res := range t.resultBuf {
				t.toReduceResults = append(t.toReduceResults, res)
				log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("taskID", t.ID()))
			}
		}
	}()

	wg.Wait()
	if timedOut {
		return fmt.Errorf("query wait to finish failed, taskID: %d: %w", t.ID(), t.TraceCtx().Err())
	}

	metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
	tr.Record("reduceResultStart")
	t.result, err = mergeRetrieveResults(t.toReduceResults)
	if err != nil {
		return err
	}
	metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
	t.result.CollectionName = t.collectionName

	if len(t.result.FieldsData) == 0 {
		log.Info("Query result is nil", zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
		t.result.Status = &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_EmptyCollection,
			Reason:    "emptly collection", // TODO
		}
		return nil
	}
	t.result.Status = &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_Success,
	}

	schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.CollectionName)
	if err != nil {
		return err
	}
	// Columns come back in the order of OutputFieldsId; guard the index so a
	// mismatching response yields an error rather than an out-of-range panic.
	if len(t.result.FieldsData) > len(t.OutputFieldsId) {
		return fmt.Errorf("query result contains %d fields, but only %d output field IDs were resolved",
			len(t.result.FieldsData), len(t.OutputFieldsId))
	}
	for i, fieldData := range t.result.FieldsData {
		for _, field := range schema.Fields {
			if field.FieldID == t.OutputFieldsId[i] {
				fieldData.FieldName = field.Name
				fieldData.FieldId = field.FieldID
				fieldData.Type = field.DataType
			}
		}
	}
	log.Info("Query PostExecute done", zap.Any("requestID", t.Base.MsgID), zap.String("requestType", "query"))
	return nil
}

// queryShard queries one shard (DML channel) through t.queryShardPolicy,
// which picks among the given shard leaders. A successful result is pushed
// into t.resultBuf.
//
// A transport error or a NotShardLeader status is reported as
// errInvalidShardLeaders so Execute can refresh the shard-leader cache; any
// other non-success status becomes a descriptive error. ctx bounds both
// obtaining the QueryNode client and the Query call, so cancellation of the
// caller's group context stops every attempt.
func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
	query := func(nodeID UniqueID, qn types.QueryNode) error {
		req := &querypb.QueryRequest{
			Req:        t.RetrieveRequest,
			DmlChannel: channelID,
			Scope:      querypb.DataScope_All,
		}

		result, err := qn.Query(ctx, req)
		if err != nil || result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
			log.Warn("QueryNode query returns error", zap.Int64("nodeID", nodeID),
				zap.Error(err), zap.String("reason", result.GetStatus().GetReason()))
			return errInvalidShardLeaders
		}
		if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
			log.Warn("QueryNode query result error", zap.Int64("nodeID", nodeID),
				zap.String("reason", result.GetStatus().GetReason()))
			return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
		}

		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
		t.resultBuf <- result
		return nil
	}

	if err := t.queryShardPolicy(ctx, t.getQueryNodePolicy, query, leaders); err != nil {
		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()),
			zap.Any("shard leaders", leaders), zap.Error(err))
		return err
	}

	return nil
}

// checkIfLoaded reports whether the data targeted by the query is loaded into
// the query nodes.
//
// It returns true immediately when the meta cache marks the collection as
// loaded. Otherwise it asks QueryCoord via ShowPartitions:
//   - with queryPartitionIDs set, a successful response is taken to mean that
//     all of those partitions are loaded;
//   - with no partitions requested, the collection counts as (partially)
//     loaded when the response lists at least one partition.
//
// An error is returned if the meta-cache lookup or the ShowPartitions call
// fails, or if QueryCoord answers with a non-success status.
func (t *queryTask) checkIfLoaded(collectionID UniqueID, queryPartitionIDs []UniqueID) (bool, error) {
	// check if collection was loaded into QueryNode
	info, err := globalMetaCache.GetCollectionInfo(t.ctx, t.collectionName)
	if err != nil {
		return false, fmt.Errorf("GetCollectionInfo failed, collectionID = %d, err = %w", collectionID, err)
	}
	if info.isLoaded {
		return true, nil
	}

	req := &querypb.ShowPartitionsRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_ShowPartitions,
			MsgID:     t.Base.MsgID,
			Timestamp: t.Base.Timestamp,
			SourceID:  Params.ProxyCfg.GetNodeID(),
		},
		CollectionID: collectionID,
	}
	// Only restrict the request when specific partitions are queried; without
	// them the response is used to discover which partitions are loaded.
	queryPartitions := len(queryPartitionIDs) > 0
	if queryPartitions {
		req.PartitionIDs = queryPartitionIDs
	}

	resp, err := t.qc.ShowPartitions(t.ctx, req)
	if err != nil {
		return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %w", collectionID, queryPartitionIDs, err)
	}
	if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, queryPartitionIDs, resp.GetStatus().GetReason())
	}

	if queryPartitions {
		// Current logic: show partitions won't return error if the given partitions are all loaded
		return true, nil
	}

	// Querying the whole collection while it is not fully loaded.
	if len(resp.GetPartitionIDs()) > 0 {
		log.Warn("collection not fully loaded, query on these partitions",
			zap.Int64("collectionID", collectionID),
			zap.Int64s("partitionIDs", resp.GetPartitionIDs()))
		return true, nil
	}

	return false, nil
}

// IDs2Expr converts ids into a boolean "in" expression on fieldName, for
// example `pk in [ 1, 2, 3 ]` for int64 IDs or `pk in [ "a", "b" ]` for
// string IDs.
//
// String IDs are emitted as quoted, escaped literals (strconv.Quote) so the
// expression parser treats them as string values rather than identifiers, and
// IDs containing quotes, backslashes or brackets survive intact. A nil ids or
// one without data yields `fieldName in [  ]`.
func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
	var idsStr string
	switch ids.GetIdField().(type) {
	case *schemapb.IDs_IntId:
		data := ids.GetIntId().GetData()
		parts := make([]string, 0, len(data))
		for _, id := range data {
			parts = append(parts, strconv.FormatInt(id, 10))
		}
		idsStr = strings.Join(parts, ", ")
	case *schemapb.IDs_StrId:
		data := ids.GetStrId().GetData()
		parts := make([]string, 0, len(data))
		for _, id := range data {
			parts = append(parts, strconv.Quote(id))
		}
		idsStr = strings.Join(parts, ", ")
	}

	return fieldName + " in [ " + idsStr + " ]"
}

// mergeRetrieveResults merges per-shard retrieve results into one
// QueryResults, keeping only the first row seen for each primary key.
//
// Results that are nil or carry no primary keys are skipped. All non-empty
// results must have the same number of field columns, otherwise an error is
// returned. If every result is empty, a QueryResults with an empty (non-nil)
// FieldsData slice is returned.
func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
	var ret *milvuspb.QueryResults
	var skipDupCnt int64
	var idSet = make(map[interface{}]struct{})

	// merge results and remove duplicates
	for _, rr := range retrieveResults {
		// Skip empty results, they would break the merge. The nil checks come
		// first so that GetSizeOfIDs is never handed a nil *schemapb.IDs.
		if rr == nil || rr.GetIds() == nil {
			continue
		}
		numPks := typeutil.GetSizeOfIDs(rr.GetIds())
		if numPks == 0 {
			continue
		}

		if ret == nil {
			ret = &milvuspb.QueryResults{
				FieldsData: make([]*schemapb.FieldData, len(rr.FieldsData)),
			}
		}

		if len(ret.FieldsData) != len(rr.FieldsData) {
			return nil, fmt.Errorf("mismatch FieldData in proxy RetrieveResults, expect %d get %d", len(ret.FieldsData), len(rr.FieldsData))
		}

		for i := 0; i < numPks; i++ {
			id := typeutil.GetPK(rr.GetIds(), int64(i))
			if _, ok := idSet[id]; ok {
				// primary keys duplicate
				skipDupCnt++
				continue
			}
			typeutil.AppendFieldData(ret.FieldsData, rr.FieldsData, int64(i))
			idSet[id] = struct{}{}
		}
	}
	log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))

	if ret == nil {
		ret = &milvuspb.QueryResults{
			FieldsData: []*schemapb.FieldData{},
		}
	}

	return ret, nil
}

func (t *queryTask) TraceCtx() context.Context {
	return t.ctx
}

func (t *queryTask) ID() UniqueID {
	return t.Base.MsgID
}

func (t *queryTask) SetID(uid UniqueID) {
	t.Base.MsgID = uid
}

func (t *queryTask) Name() string {
	return RetrieveTaskName
}

func (t *queryTask) Type() commonpb.MsgType {
	return t.Base.MsgType
}

func (t *queryTask) BeginTs() Timestamp {
	return t.Base.Timestamp
}

func (t *queryTask) EndTs() Timestamp {
	return t.Base.Timestamp
}

func (t *queryTask) SetTs(ts Timestamp) {
	t.Base.Timestamp = ts
}

func (t *queryTask) OnEnqueue() error {
	t.Base.MsgType = commonpb.MsgType_Retrieve
	return nil
}