package proxy

import (
	"context"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

// createDatabaseTask carries a CreateDatabase request through the proxy task
// scheduler and sends it to RootCoord.
type createDatabaseTask struct {
	baseTask
	Condition
	*milvuspb.CreateDatabaseRequest
	ctx       context.Context
	rootCoord types.RootCoordClient
	// result holds the status returned by RootCoord.CreateDatabase.
	result *commonpb.Status

	// replicateMsgStream is handed to SendReplicateMessagePack after a
	// successful create.
	replicateMsgStream msgstream.MsgStream
}

// TraceCtx returns the context the task was constructed with.
func (cdt *createDatabaseTask) TraceCtx() context.Context {
	return cdt.ctx
}

// ID returns the message ID stored in the request's Base.
func (cdt *createDatabaseTask) ID() UniqueID {
	base := cdt.Base
	return base.MsgID
}

// SetID stores uid as the request's message ID.
func (cdt *createDatabaseTask) SetID(uid UniqueID) {
	base := cdt.Base
	base.MsgID = uid
}

// Name returns the task name constant for database creation.
func (cdt *createDatabaseTask) Name() string {
	return CreateDatabaseTaskName
}

// Type returns the message type recorded in the request's Base.
func (cdt *createDatabaseTask) Type() commonpb.MsgType {
	base := cdt.Base
	return base.MsgType
}

// BeginTs returns the timestamp recorded in the request's Base.
func (cdt *createDatabaseTask) BeginTs() Timestamp {
	base := cdt.Base
	return base.Timestamp
}

// EndTs is identical to BeginTs: the task has a single timestamp.
func (cdt *createDatabaseTask) EndTs() Timestamp {
	return cdt.BeginTs()
}

// SetTs records ts as the request's timestamp.
func (cdt *createDatabaseTask) SetTs(ts Timestamp) {
	base := cdt.Base
	base.Timestamp = ts
}

// OnEnqueue ensures the request has a Base, then stamps it with the
// CreateDatabase message type and this node's ID.
func (cdt *createDatabaseTask) OnEnqueue() error {
	base := cdt.Base
	if base == nil {
		base = commonpbutil.NewMsgBase()
		cdt.Base = base
	}
	base.MsgType = commonpb.MsgType_CreateDatabase
	base.SourceID = paramtable.GetNodeID()
	return nil
}

// PreExecute rejects the request up front when the database name is invalid.
func (cdt *createDatabaseTask) PreExecute(ctx context.Context) error {
	return ValidateDatabaseName(cdt.GetDbName())
}

// Execute asks RootCoord to create the database.
//
// The returned status is always stored in cdt.result, and the RPC error (if
// any) is returned unchanged. The request is forwarded to the replicate
// message stream only when the status reports ErrorCode_Success.
func (cdt *createDatabaseTask) Execute(ctx context.Context) error {
	status, err := cdt.rootCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest)
	cdt.result = status

	succeeded := status != nil && status.ErrorCode == commonpb.ErrorCode_Success
	if succeeded {
		SendReplicateMessagePack(ctx, cdt.replicateMsgStream, cdt.CreateDatabaseRequest)
	}
	return err
}

// PostExecute has nothing to do for database creation.
func (cdt *createDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}

// dropDatabaseTask carries a DropDatabase request through the proxy task
// scheduler and sends it to RootCoord.
type dropDatabaseTask struct {
	baseTask
	Condition
	*milvuspb.DropDatabaseRequest
	ctx       context.Context
	rootCoord types.RootCoordClient
	// result holds the status returned by RootCoord.DropDatabase.
	result *commonpb.Status

	// replicateMsgStream is handed to SendReplicateMessagePack after a
	// successful drop.
	replicateMsgStream msgstream.MsgStream
}

// TraceCtx returns the context the task was constructed with.
func (ddt *dropDatabaseTask) TraceCtx() context.Context {
	return ddt.ctx
}

// ID returns the message ID stored in the request's Base.
func (ddt *dropDatabaseTask) ID() UniqueID {
	return ddt.Base.MsgID
}

// SetID stores uid as the request's message ID.
func (ddt *dropDatabaseTask) SetID(uid UniqueID) {
	ddt.Base.MsgID = uid
}

// Name returns the task name for database drops. It previously returned the
// collection-drop name, which mislabelled this task.
func (ddt *dropDatabaseTask) Name() string {
	return DropDatabaseTaskName
}

// Type returns the message type recorded in the request's Base.
func (ddt *dropDatabaseTask) Type() commonpb.MsgType {
	return ddt.Base.MsgType
}

// BeginTs returns the timestamp recorded in the request's Base.
func (ddt *dropDatabaseTask) BeginTs() Timestamp {
	return ddt.Base.Timestamp
}

// EndTs is identical to BeginTs: the task has a single timestamp.
func (ddt *dropDatabaseTask) EndTs() Timestamp {
	return ddt.Base.Timestamp
}

// SetTs records ts as the request's timestamp.
func (ddt *dropDatabaseTask) SetTs(ts Timestamp) {
	ddt.Base.Timestamp = ts
}

// OnEnqueue ensures the request has a Base, then stamps it with the
// DropDatabase message type and this node's ID.
func (ddt *dropDatabaseTask) OnEnqueue() error {
	if ddt.Base == nil {
		ddt.Base = commonpbutil.NewMsgBase()
	}
	ddt.Base.MsgType = commonpb.MsgType_DropDatabase
	ddt.Base.SourceID = paramtable.GetNodeID()
	return nil
}

// PreExecute rejects the request up front when the database name is invalid.
func (ddt *dropDatabaseTask) PreExecute(ctx context.Context) error {
	return ValidateDatabaseName(ddt.GetDbName())
}

// Execute asks RootCoord to drop the database.
//
// The returned status is always stored in ddt.result, and the RPC error (if
// any) is returned unchanged. On an ErrorCode_Success status, the database is
// also removed from globalMetaCache and the request is forwarded to the
// replicate message stream.
func (ddt *dropDatabaseTask) Execute(ctx context.Context) error {
	status, err := ddt.rootCoord.DropDatabase(ctx, ddt.DropDatabaseRequest)
	ddt.result = status

	if status == nil || status.ErrorCode != commonpb.ErrorCode_Success {
		return err
	}

	globalMetaCache.RemoveDatabase(ctx, ddt.DbName)
	SendReplicateMessagePack(ctx, ddt.replicateMsgStream, ddt.DropDatabaseRequest)
	return err
}

// PostExecute has nothing to do for database drops.
func (ddt *dropDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}

// listDatabaseTask carries a ListDatabases request through the proxy task
// scheduler and sends it to RootCoord.
type listDatabaseTask struct {
	baseTask
	Condition
	*milvuspb.ListDatabasesRequest
	ctx       context.Context
	rootCoord types.RootCoordClient
	// result holds the response returned by RootCoord.ListDatabases.
	result *milvuspb.ListDatabasesResponse
}

// TraceCtx returns the context the task was constructed with.
func (ldt *listDatabaseTask) TraceCtx() context.Context {
	return ldt.ctx
}

// ID returns the message ID stored in the request's Base.
func (ldt *listDatabaseTask) ID() UniqueID {
	return ldt.Base.MsgID
}

// SetID stores uid as the request's message ID.
func (ldt *listDatabaseTask) SetID(uid UniqueID) {
	ldt.Base.MsgID = uid
}

// Name returns the task name constant for listing databases.
func (ldt *listDatabaseTask) Name() string {
	return ListDatabaseTaskName
}

// Type returns the message type recorded in the request's Base.
func (ldt *listDatabaseTask) Type() commonpb.MsgType {
	return ldt.Base.MsgType
}

// BeginTs returns the timestamp recorded in the request's Base.
func (ldt *listDatabaseTask) BeginTs() Timestamp {
	return ldt.Base.Timestamp
}

// EndTs is identical to BeginTs: the task has a single timestamp.
func (ldt *listDatabaseTask) EndTs() Timestamp {
	return ldt.Base.Timestamp
}

// SetTs records ts as the request's timestamp.
func (ldt *listDatabaseTask) SetTs(ts Timestamp) {
	ldt.Base.Timestamp = ts
}

// OnEnqueue ensures the request has a Base, then stamps it with the
// ListDatabases message type and this node's ID.
//
// Like the other database tasks, it keeps a caller-supplied Base instead of
// unconditionally replacing it. Only MsgType and SourceID are overwritten
// here.
func (ldt *listDatabaseTask) OnEnqueue() error {
	if ldt.Base == nil {
		ldt.Base = commonpbutil.NewMsgBase()
	}
	ldt.Base.MsgType = commonpb.MsgType_ListDatabases
	ldt.Base.SourceID = paramtable.GetNodeID()
	return nil
}

// PreExecute performs no validation for listing databases.
func (ldt *listDatabaseTask) PreExecute(ctx context.Context) error {
	return nil
}

// Execute calls RootCoord.ListDatabases and stores the response in
// ldt.result. The RPC error is returned unchanged.
//
// The outgoing context is first passed through AppendUserInfoForRPC. This
// presumably attaches the caller's identity so RootCoord can scope the list;
// TODO confirm.
func (ldt *listDatabaseTask) Execute(ctx context.Context) error {
	rpcCtx := AppendUserInfoForRPC(ctx)
	resp, err := ldt.rootCoord.ListDatabases(rpcCtx, ldt.ListDatabasesRequest)
	ldt.result = resp
	return err
}

// PostExecute has nothing to do for listing databases.
func (ldt *listDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}

// alterDatabaseTask carries an AlterDatabase request through the proxy task
// scheduler and sends it to RootCoord.
type alterDatabaseTask struct {
	baseTask
	Condition
	*milvuspb.AlterDatabaseRequest
	ctx       context.Context
	rootCoord types.RootCoordClient
	// result holds the status returned by RootCoord.AlterDatabase.
	result *commonpb.Status
}

// TraceCtx returns the context the task was constructed with.
func (t *alterDatabaseTask) TraceCtx() context.Context {
	return t.ctx
}

// ID returns the message ID stored in the request's Base.
func (t *alterDatabaseTask) ID() UniqueID {
	base := t.Base
	return base.MsgID
}

// SetID stores uid as the request's message ID.
func (t *alterDatabaseTask) SetID(uid UniqueID) {
	base := t.Base
	base.MsgID = uid
}

// Name returns the task name constant for altering a database.
func (t *alterDatabaseTask) Name() string {
	return AlterDatabaseTaskName
}

// Type returns the message type recorded in the request's Base.
func (t *alterDatabaseTask) Type() commonpb.MsgType {
	base := t.Base
	return base.MsgType
}

// BeginTs returns the timestamp recorded in the request's Base.
func (t *alterDatabaseTask) BeginTs() Timestamp {
	base := t.Base
	return base.Timestamp
}

// EndTs is identical to BeginTs: the task has a single timestamp.
func (t *alterDatabaseTask) EndTs() Timestamp {
	return t.BeginTs()
}

// SetTs records ts as the request's timestamp.
func (t *alterDatabaseTask) SetTs(ts Timestamp) {
	base := t.Base
	base.Timestamp = ts
}

// OnEnqueue only ensures the request has a Base. The message type and source
// ID are not set here; this task stamps them in PreExecute.
func (t *alterDatabaseTask) OnEnqueue() error {
	if t.Base != nil {
		return nil
	}
	t.Base = commonpbutil.NewMsgBase()
	return nil
}

// PreExecute stamps the request's Base with the AlterDatabase message type
// and this node's ID.
func (t *alterDatabaseTask) PreExecute(ctx context.Context) error {
	t.Base.MsgType = commonpb.MsgType_AlterDatabase
	t.Base.SourceID = paramtable.GetNodeID()

	return nil
}

// Execute forwards the alteration to RootCoord.
//
// Transport errors and a failing status returned by RootCoord are both
// checked by merr.CheckRPCCall. Either one is logged and returned as an
// error. t.result is set only when the call succeeded.
func (t *alterDatabaseTask) Execute(ctx context.Context) error {
	req := &rootcoordpb.AlterDatabaseRequest{
		Base:       t.AlterDatabaseRequest.GetBase(),
		DbName:     t.AlterDatabaseRequest.GetDbName(),
		DbId:       t.AlterDatabaseRequest.GetDbId(),
		Properties: t.AlterDatabaseRequest.GetProperties(),
	}

	ret, err := t.rootCoord.AlterDatabase(ctx, req)
	// Check the freshly returned status (ret). t.result has not been assigned
	// yet, so checking it would never see what RootCoord actually reported.
	if err := merr.CheckRPCCall(ret, err); err != nil {
		log.Warn("AlterDatabase failed", zap.String("dbName", req.GetDbName()), zap.Error(err))
		return err
	}

	t.result = ret
	return nil
}

// PostExecute has nothing to do for altering a database.
func (t *alterDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}

// describeDatabaseTask carries a DescribeDatabase request through the proxy
// task scheduler and sends it to RootCoord.
type describeDatabaseTask struct {
	baseTask
	Condition
	*milvuspb.DescribeDatabaseRequest
	ctx       context.Context
	rootCoord types.RootCoordClient
	// result holds the response built from RootCoord.DescribeDatabase.
	result *milvuspb.DescribeDatabaseResponse
}

// TraceCtx returns the context the task was constructed with.
func (t *describeDatabaseTask) TraceCtx() context.Context {
	return t.ctx
}

// ID returns the message ID stored in the request's Base.
func (t *describeDatabaseTask) ID() UniqueID {
	return t.Base.MsgID
}

// SetID stores uid as the request's message ID.
func (t *describeDatabaseTask) SetID(uid UniqueID) {
	t.Base.MsgID = uid
}

// Name returns the task name for describing a database. It previously
// returned the alter-database name, which mislabelled this task.
func (t *describeDatabaseTask) Name() string {
	return DescribeDatabaseTaskName
}

// Type returns the message type recorded in the request's Base.
func (t *describeDatabaseTask) Type() commonpb.MsgType {
	return t.Base.MsgType
}

// BeginTs returns the timestamp recorded in the request's Base.
func (t *describeDatabaseTask) BeginTs() Timestamp {
	return t.Base.Timestamp
}

// EndTs is identical to BeginTs: the task has a single timestamp.
func (t *describeDatabaseTask) EndTs() Timestamp {
	return t.Base.Timestamp
}

// SetTs records ts as the request's timestamp.
func (t *describeDatabaseTask) SetTs(ts Timestamp) {
	t.Base.Timestamp = ts
}

// OnEnqueue only ensures the request has a Base. The message type and source
// ID are stamped in PreExecute.
func (t *describeDatabaseTask) OnEnqueue() error {
	if t.Base == nil {
		t.Base = commonpbutil.NewMsgBase()
	}
	return nil
}

// PreExecute stamps the request's Base with the DescribeDatabase message type
// and this node's ID. It previously used MsgType_AlterCollection, which
// misclassified the request.
func (t *describeDatabaseTask) PreExecute(ctx context.Context) error {
	t.Base.MsgType = commonpb.MsgType_DescribeDatabase
	t.Base.SourceID = paramtable.GetNodeID()

	return nil
}

// Execute queries RootCoord for the database's description.
//
// Transport errors and a failing status returned by RootCoord are both
// checked by merr.CheckRPCCall, then logged and returned. On success, the
// RootCoord response is copied into a milvuspb.DescribeDatabaseResponse
// stored in t.result.
func (t *describeDatabaseTask) Execute(ctx context.Context) error {
	req := &rootcoordpb.DescribeDatabaseRequest{
		Base:   t.DescribeDatabaseRequest.GetBase(),
		DbName: t.DescribeDatabaseRequest.GetDbName(),
	}
	ret, err := t.rootCoord.DescribeDatabase(ctx, req)
	if err := merr.CheckRPCCall(ret, err); err != nil {
		log.Warn("DescribeDatabase failed", zap.String("dbName", req.GetDbName()), zap.Error(err))
		return err
	}

	t.result = &milvuspb.DescribeDatabaseResponse{
		Status:           ret.GetStatus(),
		DbName:           ret.GetDbName(),
		DbID:             ret.GetDbID(),
		CreatedTimestamp: ret.GetCreatedTimestamp(),
		Properties:       ret.GetProperties(),
	}
	return nil
}

// PostExecute has nothing to do for describing a database.
func (t *describeDatabaseTask) PostExecute(ctx context.Context) error {
	return nil
}