package proxy

import (
	"context"
	"strings"
	"testing"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc/metadata"

	"github.com/milvus-io/milvus/internal/mocks"
	"github.com/milvus-io/milvus/pkg/util"
	"github.com/milvus-io/milvus/pkg/util/crypto"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

// TestValidAuth exercises a local validAuth helper that base64-decodes an
// authorization token, splits it with parseMD and verifies the resulting
// credentials with passwordVerify against globalMetaCache.
func TestValidAuth(t *testing.T) {
	// validAuth validates the authentication carried in the authorization
	// metadata values. NOTE(review): presumably mirrors the production
	// credential check — confirm it stays in sync with it.
	validAuth := func(ctx context.Context, authorization []string) bool {
		if len(authorization) < 1 {
			return false
		}
		token := authorization[0]
		rawToken, err := crypto.Base64Decode(token)
		if err != nil {
			// A token that is not valid base64 cannot carry credentials;
			// reject it explicitly instead of parsing a partial decode.
			return false
		}
		username, password := parseMD(rawToken)
		if username == "" || password == "" {
			return false
		}
		return passwordVerify(ctx, username, password, globalMetaCache)
	}

	ctx := context.Background()
	// no metadata
	res := validAuth(ctx, nil)
	assert.False(t, res)
	// illegal metadata
	res = validAuth(ctx, []string{"xxx"})
	assert.False(t, res)
	// normal metadata; the mock root coord presumably knows mockUser:mockPass
	rootCoord := &MockRootCoordClientInterface{}
	queryCoord := &mocks.MockQueryCoordClient{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
	assert.NoError(t, err)
	res = validAuth(ctx, []string{crypto.Base64Encode("mockUser:mockPass")})
	assert.True(t, res)

	// token without a "user:password" form must be rejected
	res = validAuth(ctx, []string{crypto.Base64Encode("mock")})
	assert.False(t, res)
}

func TestValidSourceID(t *testing.T) {
	ctx := context.Background()
	// no metadata
	res := validSourceID(ctx, nil)
	assert.False(t, res)
	// illegal metadata
	res = validSourceID(ctx, []string{"invalid_sourceid"})
	assert.False(t, res)
	// normal sourceId
	res = validSourceID(ctx, []string{crypto.Base64Encode(util.MemberCredID)})
	assert.True(t, res)
}

// TestAuthenticationInterceptor drives AuthenticationInterceptor with
// authorization enabled through missing, invalid and valid metadata, including
// the API-key path backed by a mocked hook.
func TestAuthenticationInterceptor(t *testing.T) {
	ctx := context.Background()
	paramtable.Get().Save(Params.CommonCfg.AuthorizationEnabled.Key, "true") // mock authorization is turned on
	defer paramtable.Get().Reset(Params.CommonCfg.AuthorizationEnabled.Key)  // restore the default authorization setting
	// Reset the package-level hook on every exit path (including panics and
	// FailNow), so a mock hook installed below cannot leak into other tests.
	defer func() { hoo = defaultHook{} }()

	// no metadata
	_, err := AuthenticationInterceptor(ctx)
	assert.Error(t, err)
	// mock metacache
	rootCoord := &MockRootCoordClientInterface{}
	queryCoord := &mocks.MockQueryCoordClient{}
	mgr := newShardClientMgr()
	err = InitMetaCache(ctx, rootCoord, queryCoord, mgr)
	assert.NoError(t, err)
	// with invalid metadata
	md := metadata.Pairs("xxx", "yyy")
	ctx = metadata.NewIncomingContext(ctx, md)
	_, err = AuthenticationInterceptor(ctx)
	assert.Error(t, err)
	// with valid username/password
	md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockUser:mockPass"))
	ctx = metadata.NewIncomingContext(ctx, md)
	_, err = AuthenticationInterceptor(ctx)
	assert.NoError(t, err)
	// with valid sourceId
	md = metadata.Pairs("sourceid", crypto.Base64Encode(util.MemberCredID))
	ctx = metadata.NewIncomingContext(ctx, md)
	_, err = AuthenticationInterceptor(ctx)
	assert.NoError(t, err)

	{
		// wrong authorization style: value is not base64-encoded
		md = metadata.Pairs(util.HeaderAuthorize, "123456")
		ctx = metadata.NewIncomingContext(ctx, md)
		_, err = AuthenticationInterceptor(ctx)
		assert.Error(t, err)
	}

	{
		// invalid user
		md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockUser2:mockPass"))
		ctx = metadata.NewIncomingContext(ctx, md)
		_, err = AuthenticationInterceptor(ctx)
		assert.Error(t, err)
	}

	{
		// default hook: an API key is expected to be rejected
		md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
		ctx = metadata.NewIncomingContext(ctx, md)
		_, err = AuthenticationInterceptor(ctx)
		assert.Error(t, err)
	}

	{
		// verify apikey error: the mocked hook fails verification
		SetMockAPIHook("", errors.New("err"))
		md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
		ctx = metadata.NewIncomingContext(ctx, md)
		_, err = AuthenticationInterceptor(ctx)
		assert.Error(t, err)
	}

	{
		// successful API-key verification: the returned context is expected to
		// carry an authorization header whose user is the hook's "mockUser"
		SetMockAPIHook("mockUser", nil)
		md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
		ctx = metadata.NewIncomingContext(ctx, md)
		authCtx, err := AuthenticationInterceptor(ctx)
		assert.NoError(t, err)
		outMD, ok := metadata.FromIncomingContext(authCtx)
		assert.True(t, ok)
		authStrArr := outMD[strings.ToLower(util.HeaderAuthorize)]
		// Guard the index so a missing header fails the test instead of panicking.
		if assert.NotEmpty(t, authStrArr) {
			rawToken, err := crypto.Base64Decode(authStrArr[0])
			assert.NoError(t, err)
			user, _ := parseMD(rawToken)
			assert.Equal(t, "mockUser", user)
		}
	}
}