mirror of https://github.com/milvus-io/milvus.git
enhance: [GoSDK] write back auto id value to row based input (#36964)
Related to #33460 --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/36981/head
parent
903c18ba26
commit
b7ffa8383c
|
@ -0,0 +1,87 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
milvusclient "github.com/milvus-io/milvus/client/v2"
|
||||
"github.com/milvus-io/milvus/client/v2/row"
|
||||
)
|
||||
|
||||
type Data struct {
|
||||
ID int64 `milvus:"name:id;primary_key;auto_id"`
|
||||
Vector []float32 `milvus:"name:vector;dim:128"`
|
||||
}
|
||||
|
||||
const (
|
||||
milvusAddr = `localhost:19530`
|
||||
nEntities, dim = 10, 128
|
||||
collectionName = "hello_row_base"
|
||||
|
||||
msgFmt = "==== %s ====\n"
|
||||
idCol, randomCol, embeddingCol = "id", "random", "vector"
|
||||
topK = 3
|
||||
)
|
||||
|
||||
func main() {
|
||||
schema, err := row.ParseSchema(&Data{})
|
||||
if err != nil {
|
||||
log.Fatal("failed to parse schema from struct", err.Error())
|
||||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
log.Printf("Field name: %s, FieldType %s, IsPrimaryKey: %t", field.Name, field.DataType, field.PrimaryKey)
|
||||
}
|
||||
schema.WithName(collectionName)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
log.Printf(msgFmt, "start connecting to Milvus")
|
||||
c, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
|
||||
Address: milvusAddr,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal("failed to connect to milvus, err: ", err.Error())
|
||||
}
|
||||
defer c.Close(ctx)
|
||||
|
||||
if has, err := c.HasCollection(ctx, milvusclient.NewHasCollectionOption(collectionName)); err != nil {
|
||||
log.Fatal("failed to check collection exists or not", err.Error())
|
||||
} else if has {
|
||||
log.Printf("collection %s alread exists, dropping it now\n", collectionName)
|
||||
c.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName))
|
||||
}
|
||||
|
||||
err = c.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema))
|
||||
if err != nil {
|
||||
log.Fatal("failed to create collection", err.Error())
|
||||
}
|
||||
|
||||
var rows []*Data
|
||||
for i := 0; i < nEntities; i++ {
|
||||
vec := make([]float32, 0, dim)
|
||||
for j := 0; j < dim; j++ {
|
||||
vec = append(vec, rand.Float32())
|
||||
}
|
||||
rows = append(rows, &Data{
|
||||
Vector: vec,
|
||||
})
|
||||
}
|
||||
|
||||
insertResult, err := c.Insert(ctx, milvusclient.NewRowBasedInsertOption(collectionName, lo.Map(rows, func(data *Data, _ int) any {
|
||||
return data
|
||||
})...))
|
||||
if err != nil {
|
||||
log.Fatal("failed to insert data: ", err.Error())
|
||||
}
|
||||
log.Println(insertResult.IDs)
|
||||
for _, row := range rows {
|
||||
// id shall be written back
|
||||
log.Println(row.ID)
|
||||
}
|
||||
|
||||
c.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName))
|
||||
}
|
|
@ -269,6 +269,25 @@ func NewArrayColumn(f *entity.Field) column.Column {
|
|||
}
|
||||
}
|
||||
|
||||
func SetField(receiver any, fieldName string, value any) error {
|
||||
candidates, err := reflectValueCandi(reflect.ValueOf(receiver))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
candidate, ok := candidates[fieldName]
|
||||
// if field not found, just return
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if candidate.v.CanSet() {
|
||||
candidate.v.Set(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type fieldCandi struct {
|
||||
name string
|
||||
v reflect.Value
|
||||
|
|
|
@ -56,7 +56,9 @@ func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ..
|
|||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// write back pks if needed
|
||||
// pks values shall be written back to struct if receiver field exists
|
||||
return option.WriteBackPKs(collection.Schema, result.IDs)
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
type InsertOption interface {
|
||||
InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error)
|
||||
CollectionName() string
|
||||
WriteBackPKs(schema *entity.Schema, pks column.Column) error
|
||||
}
|
||||
|
||||
type UpsertOption interface {
|
||||
|
@ -52,6 +53,11 @@ type columnBasedDataOption struct {
|
|||
columns []column.Column
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WriteBackPKs(_ *entity.Schema, _ column.Column) error {
|
||||
// column based data option need not write back pk
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, columns ...column.Column) ([]*schemapb.FieldData, int, error) {
|
||||
// setup dynamic related var
|
||||
isDynamic := colSchema.EnableDynamicField
|
||||
|
@ -296,6 +302,28 @@ func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (opt *rowBasedDataOption) WriteBackPKs(sch *entity.Schema, pks column.Column) error {
|
||||
pkField := sch.PKField()
|
||||
// not auto id, return
|
||||
if pkField == nil || !pkField.AutoID {
|
||||
return nil
|
||||
}
|
||||
if len(opt.rows) != pks.Len() {
|
||||
return errors.New("input row count is not equal to result pk length")
|
||||
}
|
||||
|
||||
for i, r := range opt.rows {
|
||||
// index range checked
|
||||
v, _ := pks.Get(i)
|
||||
err := row.SetField(r, pkField.Name, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type DeleteOption interface {
|
||||
Request() *milvuspb.DeleteRequest
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue