mirror of https://github.com/go-gitea/gitea.git
Remove x/net/context vendor by using std package (#5202)
* Update dep github.com/markbates/goth * Update dep github.com/blevesearch/bleve * Update dep golang.org/x/oauth2 * Fix github.com/blevesearch/bleve to c74e08f039e56cef576e4336382b2a2d12d9e026 * Update dep golang.org/x/oauth2pull/4772/head^2
parent
b3000ae623
commit
4c1f1f9646
|
@ -90,7 +90,7 @@
|
||||||
revision = "3a771d992973f24aa725d07868b467d1ddfceafb"
|
revision = "3a771d992973f24aa725d07868b467d1ddfceafb"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:67351095005f164e748a5a21899d1403b03878cb2d40a7b0f742376e6eeda974"
|
digest = "1:c10f35be6200b09e26da267ca80f837315093ecaba27e7a223071380efb9dd32"
|
||||||
name = "github.com/blevesearch/bleve"
|
name = "github.com/blevesearch/bleve"
|
||||||
packages = [
|
packages = [
|
||||||
".",
|
".",
|
||||||
|
@ -135,7 +135,7 @@
|
||||||
"search/searcher",
|
"search/searcher",
|
||||||
]
|
]
|
||||||
pruneopts = "NUT"
|
pruneopts = "NUT"
|
||||||
revision = "ff210fbc6d348ad67aa5754eaea11a463fcddafd"
|
revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
|
@ -557,7 +557,7 @@
|
||||||
revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf"
|
revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:23f75ae90fcc38dac6fad6881006ea7d0f2c78db5f9f81f3df558dc91460e61f"
|
digest = "1:4b992ec853d0ea9bac3dcf09a64af61de1a392e6cb0eef2204c0c92f4ae6b911"
|
||||||
name = "github.com/markbates/goth"
|
name = "github.com/markbates/goth"
|
||||||
packages = [
|
packages = [
|
||||||
".",
|
".",
|
||||||
|
@ -572,8 +572,8 @@
|
||||||
"providers/twitter",
|
"providers/twitter",
|
||||||
]
|
]
|
||||||
pruneopts = "NUT"
|
pruneopts = "NUT"
|
||||||
revision = "f9c6649ab984d6ea71ef1e13b7b1cdffcf4592d3"
|
revision = "bc6d8ddf751a745f37ca5567dbbfc4157bbf5da9"
|
||||||
version = "v1.46.1"
|
version = "v1.47.2"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5"
|
digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5"
|
||||||
|
@ -809,10 +809,11 @@
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
digest = "1:6d5ed712653ea5321fe3e3475ab2188cf362a4e0d31e9fd3acbd4dfbbca0d680"
|
digest = "1:d0a0bdd2b64d981aa4e6a1ade90431d042cd7fa31b584e33d45e62cbfec43380"
|
||||||
name = "golang.org/x/net"
|
name = "golang.org/x/net"
|
||||||
packages = [
|
packages = [
|
||||||
"context",
|
"context",
|
||||||
|
"context/ctxhttp",
|
||||||
"html",
|
"html",
|
||||||
"html/atom",
|
"html/atom",
|
||||||
"html/charset",
|
"html/charset",
|
||||||
|
@ -821,14 +822,15 @@
|
||||||
revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344"
|
revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:8159a9cda4b8810aaaeb0d60e2fa68e2fd86d8af4ec8f5059830839e3c8d93d5"
|
branch = "master"
|
||||||
|
digest = "1:274a6321a5a9f185eeb3fab5d7d8397e0e9f57737490d749f562c7e205ffbc2e"
|
||||||
name = "golang.org/x/oauth2"
|
name = "golang.org/x/oauth2"
|
||||||
packages = [
|
packages = [
|
||||||
".",
|
".",
|
||||||
"internal",
|
"internal",
|
||||||
]
|
]
|
||||||
pruneopts = "NUT"
|
pruneopts = "NUT"
|
||||||
revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
|
revision = "c453e0c757598fd055e170a3a359263c91e13153"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3"
|
digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3"
|
||||||
|
|
10
Gopkg.toml
10
Gopkg.toml
|
@ -14,6 +14,12 @@ ignored = ["google.golang.org/appengine*"]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
name = "code.gitea.io/sdk"
|
name = "code.gitea.io/sdk"
|
||||||
|
|
||||||
|
[[constraint]]
|
||||||
|
# branch = "master"
|
||||||
|
revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
|
||||||
|
name = "github.com/blevesearch/bleve"
|
||||||
|
#Not targetting v0.7.0 since standard where use only just after this tag
|
||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
|
revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
|
||||||
name = "golang.org/x/crypto"
|
name = "golang.org/x/crypto"
|
||||||
|
@ -61,7 +67,7 @@ ignored = ["google.golang.org/appengine*"]
|
||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
name = "github.com/markbates/goth"
|
name = "github.com/markbates/goth"
|
||||||
version = "1.46.1"
|
version = "1.47.2"
|
||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
|
@ -105,7 +111,7 @@ ignored = ["google.golang.org/appengine*"]
|
||||||
source = "github.com/go-gitea/bolt"
|
source = "github.com/go-gitea/bolt"
|
||||||
|
|
||||||
[[override]]
|
[[override]]
|
||||||
revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061"
|
branch = "master"
|
||||||
name = "golang.org/x/oauth2"
|
name = "golang.org/x/oauth2"
|
||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
|
|
|
@ -15,11 +15,12 @@
|
||||||
package bleve
|
package bleve
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/blevesearch/bleve/document"
|
"github.com/blevesearch/bleve/document"
|
||||||
"github.com/blevesearch/bleve/index"
|
"github.com/blevesearch/bleve/index"
|
||||||
"github.com/blevesearch/bleve/index/store"
|
"github.com/blevesearch/bleve/index/store"
|
||||||
"github.com/blevesearch/bleve/mapping"
|
"github.com/blevesearch/bleve/mapping"
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Batch groups together multiple Index and Delete
|
// A Batch groups together multiple Index and Delete
|
||||||
|
|
|
@ -100,8 +100,8 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
|
||||||
// prepare new index snapshot
|
// prepare new index snapshot
|
||||||
newSnapshot := &IndexSnapshot{
|
newSnapshot := &IndexSnapshot{
|
||||||
parent: s,
|
parent: s,
|
||||||
segment: make([]*SegmentSnapshot, nsegs, nsegs+1),
|
segment: make([]*SegmentSnapshot, 0, nsegs+1),
|
||||||
offsets: make([]uint64, nsegs, nsegs+1),
|
offsets: make([]uint64, 0, nsegs+1),
|
||||||
internal: make(map[string][]byte, len(s.root.internal)),
|
internal: make(map[string][]byte, len(s.root.internal)),
|
||||||
epoch: s.nextSnapshotEpoch,
|
epoch: s.nextSnapshotEpoch,
|
||||||
refs: 1,
|
refs: 1,
|
||||||
|
@ -124,24 +124,29 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newSnapshot.segment[i] = &SegmentSnapshot{
|
|
||||||
|
newss := &SegmentSnapshot{
|
||||||
id: s.root.segment[i].id,
|
id: s.root.segment[i].id,
|
||||||
segment: s.root.segment[i].segment,
|
segment: s.root.segment[i].segment,
|
||||||
cachedDocs: s.root.segment[i].cachedDocs,
|
cachedDocs: s.root.segment[i].cachedDocs,
|
||||||
}
|
}
|
||||||
s.root.segment[i].segment.AddRef()
|
|
||||||
|
|
||||||
// apply new obsoletions
|
// apply new obsoletions
|
||||||
if s.root.segment[i].deleted == nil {
|
if s.root.segment[i].deleted == nil {
|
||||||
newSnapshot.segment[i].deleted = delta
|
newss.deleted = delta
|
||||||
} else {
|
} else {
|
||||||
newSnapshot.segment[i].deleted = roaring.Or(s.root.segment[i].deleted, delta)
|
newss.deleted = roaring.Or(s.root.segment[i].deleted, delta)
|
||||||
}
|
}
|
||||||
|
|
||||||
newSnapshot.offsets[i] = running
|
// check for live size before copying
|
||||||
running += s.root.segment[i].Count()
|
if newss.LiveSize() > 0 {
|
||||||
|
newSnapshot.segment = append(newSnapshot.segment, newss)
|
||||||
|
s.root.segment[i].segment.AddRef()
|
||||||
|
newSnapshot.offsets = append(newSnapshot.offsets, running)
|
||||||
|
running += s.root.segment[i].Count()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// append new segment, if any, to end of the new index snapshot
|
// append new segment, if any, to end of the new index snapshot
|
||||||
if next.data != nil {
|
if next.data != nil {
|
||||||
newSegmentSnapshot := &SegmentSnapshot{
|
newSegmentSnapshot := &SegmentSnapshot{
|
||||||
|
@ -193,6 +198,12 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
|
||||||
// prepare new index snapshot
|
// prepare new index snapshot
|
||||||
currSize := len(s.root.segment)
|
currSize := len(s.root.segment)
|
||||||
newSize := currSize + 1 - len(nextMerge.old)
|
newSize := currSize + 1 - len(nextMerge.old)
|
||||||
|
|
||||||
|
// empty segments deletion
|
||||||
|
if nextMerge.new == nil {
|
||||||
|
newSize--
|
||||||
|
}
|
||||||
|
|
||||||
newSnapshot := &IndexSnapshot{
|
newSnapshot := &IndexSnapshot{
|
||||||
parent: s,
|
parent: s,
|
||||||
segment: make([]*SegmentSnapshot, 0, newSize),
|
segment: make([]*SegmentSnapshot, 0, newSize),
|
||||||
|
@ -210,7 +221,7 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
|
||||||
segmentID := s.root.segment[i].id
|
segmentID := s.root.segment[i].id
|
||||||
if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
|
if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
|
||||||
// this segment is going away, see if anything else was deleted since we started the merge
|
// this segment is going away, see if anything else was deleted since we started the merge
|
||||||
if s.root.segment[i].deleted != nil {
|
if segSnapAtMerge != nil && s.root.segment[i].deleted != nil {
|
||||||
// assume all these deletes are new
|
// assume all these deletes are new
|
||||||
deletedSince := s.root.segment[i].deleted
|
deletedSince := s.root.segment[i].deleted
|
||||||
// if we already knew about some of them, remove
|
// if we already knew about some of them, remove
|
||||||
|
@ -224,7 +235,13 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
|
||||||
newSegmentDeleted.Add(uint32(newDocNum))
|
newSegmentDeleted.Add(uint32(newDocNum))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
// clean up the old segment map to figure out the
|
||||||
|
// obsolete segments wrt root in meantime, whatever
|
||||||
|
// segments left behind in old map after processing
|
||||||
|
// the root segments would be the obsolete segment set
|
||||||
|
delete(nextMerge.old, segmentID)
|
||||||
|
|
||||||
|
} else if s.root.segment[i].LiveSize() > 0 {
|
||||||
// this segment is staying
|
// this segment is staying
|
||||||
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
|
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
|
||||||
id: s.root.segment[i].id,
|
id: s.root.segment[i].id,
|
||||||
|
@ -238,14 +255,35 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// put new segment at end
|
// before the newMerge introduction, need to clean the newly
|
||||||
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
|
// merged segment wrt the current root segments, hence
|
||||||
id: nextMerge.id,
|
// applying the obsolete segment contents to newly merged segment
|
||||||
segment: nextMerge.new, // take ownership for nextMerge.new's ref-count
|
for segID, ss := range nextMerge.old {
|
||||||
deleted: newSegmentDeleted,
|
obsoleted := ss.DocNumbersLive()
|
||||||
cachedDocs: &cachedDocs{cache: nil},
|
if obsoleted != nil {
|
||||||
})
|
obsoletedIter := obsoleted.Iterator()
|
||||||
newSnapshot.offsets = append(newSnapshot.offsets, running)
|
for obsoletedIter.HasNext() {
|
||||||
|
oldDocNum := obsoletedIter.Next()
|
||||||
|
newDocNum := nextMerge.oldNewDocNums[segID][oldDocNum]
|
||||||
|
newSegmentDeleted.Add(uint32(newDocNum))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// In case where all the docs in the newly merged segment getting
|
||||||
|
// deleted by the time we reach here, can skip the introduction.
|
||||||
|
if nextMerge.new != nil &&
|
||||||
|
nextMerge.new.Count() > newSegmentDeleted.GetCardinality() {
|
||||||
|
// put new segment at end
|
||||||
|
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
|
||||||
|
id: nextMerge.id,
|
||||||
|
segment: nextMerge.new, // take ownership for nextMerge.new's ref-count
|
||||||
|
deleted: newSegmentDeleted,
|
||||||
|
cachedDocs: &cachedDocs{cache: nil},
|
||||||
|
})
|
||||||
|
newSnapshot.offsets = append(newSnapshot.offsets, running)
|
||||||
|
}
|
||||||
|
|
||||||
|
newSnapshot.AddRef() // 1 ref for the nextMerge.notify response
|
||||||
|
|
||||||
// swap in new segment
|
// swap in new segment
|
||||||
rootPrev := s.root
|
rootPrev := s.root
|
||||||
|
@ -257,7 +295,8 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
|
||||||
_ = rootPrev.DecRef()
|
_ = rootPrev.DecRef()
|
||||||
}
|
}
|
||||||
|
|
||||||
// notify merger we incorporated this
|
// notify requester that we incorporated this
|
||||||
|
nextMerge.notify <- newSnapshot
|
||||||
close(nextMerge.notify)
|
close(nextMerge.notify)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,9 @@
|
||||||
package scorch
|
package scorch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -28,6 +31,13 @@ import (
|
||||||
|
|
||||||
func (s *Scorch) mergerLoop() {
|
func (s *Scorch) mergerLoop() {
|
||||||
var lastEpochMergePlanned uint64
|
var lastEpochMergePlanned uint64
|
||||||
|
mergePlannerOptions, err := s.parseMergePlannerOptions()
|
||||||
|
if err != nil {
|
||||||
|
s.fireAsyncError(fmt.Errorf("mergePlannerOption json parsing err: %v", err))
|
||||||
|
s.asyncTasks.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
OUTER:
|
OUTER:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -45,7 +55,7 @@ OUTER:
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
// lets get started
|
// lets get started
|
||||||
err := s.planMergeAtSnapshot(ourSnapshot)
|
err := s.planMergeAtSnapshot(ourSnapshot, mergePlannerOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.fireAsyncError(fmt.Errorf("merging err: %v", err))
|
s.fireAsyncError(fmt.Errorf("merging err: %v", err))
|
||||||
_ = ourSnapshot.DecRef()
|
_ = ourSnapshot.DecRef()
|
||||||
|
@ -58,51 +68,49 @@ OUTER:
|
||||||
_ = ourSnapshot.DecRef()
|
_ = ourSnapshot.DecRef()
|
||||||
|
|
||||||
// tell the persister we're waiting for changes
|
// tell the persister we're waiting for changes
|
||||||
// first make a notification chan
|
// first make a epochWatcher chan
|
||||||
notifyUs := make(notificationChan)
|
ew := &epochWatcher{
|
||||||
|
epoch: lastEpochMergePlanned,
|
||||||
|
notifyCh: make(notificationChan, 1),
|
||||||
|
}
|
||||||
|
|
||||||
// give it to the persister
|
// give it to the persister
|
||||||
select {
|
select {
|
||||||
case <-s.closeCh:
|
case <-s.closeCh:
|
||||||
break OUTER
|
break OUTER
|
||||||
case s.persisterNotifier <- notifyUs:
|
case s.persisterNotifier <- ew:
|
||||||
}
|
}
|
||||||
|
|
||||||
// check again
|
// now wait for persister (but also detect close)
|
||||||
s.rootLock.RLock()
|
|
||||||
ourSnapshot = s.root
|
|
||||||
ourSnapshot.AddRef()
|
|
||||||
s.rootLock.RUnlock()
|
|
||||||
|
|
||||||
if ourSnapshot.epoch != lastEpochMergePlanned {
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
// lets get started
|
|
||||||
err := s.planMergeAtSnapshot(ourSnapshot)
|
|
||||||
if err != nil {
|
|
||||||
s.fireAsyncError(fmt.Errorf("merging err: %v", err))
|
|
||||||
_ = ourSnapshot.DecRef()
|
|
||||||
continue OUTER
|
|
||||||
}
|
|
||||||
lastEpochMergePlanned = ourSnapshot.epoch
|
|
||||||
|
|
||||||
s.fireEvent(EventKindMergerProgress, time.Since(startTime))
|
|
||||||
}
|
|
||||||
_ = ourSnapshot.DecRef()
|
|
||||||
|
|
||||||
// now wait for it (but also detect close)
|
|
||||||
select {
|
select {
|
||||||
case <-s.closeCh:
|
case <-s.closeCh:
|
||||||
break OUTER
|
break OUTER
|
||||||
case <-notifyUs:
|
case <-ew.notifyCh:
|
||||||
// woken up, next loop should pick up work
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.asyncTasks.Done()
|
s.asyncTasks.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
|
func (s *Scorch) parseMergePlannerOptions() (*mergeplan.MergePlanOptions,
|
||||||
|
error) {
|
||||||
|
mergePlannerOptions := mergeplan.DefaultMergePlanOptions
|
||||||
|
if v, ok := s.config["scorchMergePlanOptions"]; ok {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return &mergePlannerOptions, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(b, &mergePlannerOptions)
|
||||||
|
if err != nil {
|
||||||
|
return &mergePlannerOptions, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &mergePlannerOptions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot,
|
||||||
|
options *mergeplan.MergePlanOptions) error {
|
||||||
// build list of zap segments in this snapshot
|
// build list of zap segments in this snapshot
|
||||||
var onlyZapSnapshots []mergeplan.Segment
|
var onlyZapSnapshots []mergeplan.Segment
|
||||||
for _, segmentSnapshot := range ourSnapshot.segment {
|
for _, segmentSnapshot := range ourSnapshot.segment {
|
||||||
|
@ -112,7 +120,7 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// give this list to the planner
|
// give this list to the planner
|
||||||
resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, nil)
|
resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("merge planning err: %v", err)
|
return fmt.Errorf("merge planning err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -122,8 +130,12 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// process tasks in serial for now
|
// process tasks in serial for now
|
||||||
var notifications []notificationChan
|
var notifications []chan *IndexSnapshot
|
||||||
for _, task := range resultMergePlan.Tasks {
|
for _, task := range resultMergePlan.Tasks {
|
||||||
|
if len(task.Segments) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
oldMap := make(map[uint64]*SegmentSnapshot)
|
oldMap := make(map[uint64]*SegmentSnapshot)
|
||||||
newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
|
newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
|
||||||
segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))
|
segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))
|
||||||
|
@ -132,40 +144,51 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
|
||||||
if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
|
if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
|
||||||
oldMap[segSnapshot.id] = segSnapshot
|
oldMap[segSnapshot.id] = segSnapshot
|
||||||
if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
|
if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
|
||||||
segmentsToMerge = append(segmentsToMerge, zapSeg)
|
if segSnapshot.LiveSize() == 0 {
|
||||||
docsToDrop = append(docsToDrop, segSnapshot.deleted)
|
oldMap[segSnapshot.id] = nil
|
||||||
|
} else {
|
||||||
|
segmentsToMerge = append(segmentsToMerge, zapSeg)
|
||||||
|
docsToDrop = append(docsToDrop, segSnapshot.deleted)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filename := zapFileName(newSegmentID)
|
var oldNewDocNums map[uint64][]uint64
|
||||||
s.markIneligibleForRemoval(filename)
|
var segment segment.Segment
|
||||||
path := s.path + string(os.PathSeparator) + filename
|
if len(segmentsToMerge) > 0 {
|
||||||
newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, DefaultChunkFactor)
|
filename := zapFileName(newSegmentID)
|
||||||
if err != nil {
|
s.markIneligibleForRemoval(filename)
|
||||||
s.unmarkIneligibleForRemoval(filename)
|
path := s.path + string(os.PathSeparator) + filename
|
||||||
return fmt.Errorf("merging failed: %v", err)
|
newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, 1024)
|
||||||
}
|
if err != nil {
|
||||||
segment, err := zap.Open(path)
|
s.unmarkIneligibleForRemoval(filename)
|
||||||
if err != nil {
|
return fmt.Errorf("merging failed: %v", err)
|
||||||
s.unmarkIneligibleForRemoval(filename)
|
}
|
||||||
return err
|
segment, err = zap.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
s.unmarkIneligibleForRemoval(filename)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
oldNewDocNums = make(map[uint64][]uint64)
|
||||||
|
for i, segNewDocNums := range newDocNums {
|
||||||
|
oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := &segmentMerge{
|
sm := &segmentMerge{
|
||||||
id: newSegmentID,
|
id: newSegmentID,
|
||||||
old: oldMap,
|
old: oldMap,
|
||||||
oldNewDocNums: make(map[uint64][]uint64),
|
oldNewDocNums: oldNewDocNums,
|
||||||
new: segment,
|
new: segment,
|
||||||
notify: make(notificationChan),
|
notify: make(chan *IndexSnapshot, 1),
|
||||||
}
|
}
|
||||||
notifications = append(notifications, sm.notify)
|
notifications = append(notifications, sm.notify)
|
||||||
for i, segNewDocNums := range newDocNums {
|
|
||||||
sm.oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
|
|
||||||
}
|
|
||||||
|
|
||||||
// give it to the introducer
|
// give it to the introducer
|
||||||
select {
|
select {
|
||||||
case <-s.closeCh:
|
case <-s.closeCh:
|
||||||
|
_ = segment.Close()
|
||||||
return nil
|
return nil
|
||||||
case s.merges <- sm:
|
case s.merges <- sm:
|
||||||
}
|
}
|
||||||
|
@ -174,7 +197,10 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
|
||||||
select {
|
select {
|
||||||
case <-s.closeCh:
|
case <-s.closeCh:
|
||||||
return nil
|
return nil
|
||||||
case <-notification:
|
case newSnapshot := <-notification:
|
||||||
|
if newSnapshot != nil {
|
||||||
|
_ = newSnapshot.DecRef()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -185,5 +211,72 @@ type segmentMerge struct {
|
||||||
old map[uint64]*SegmentSnapshot
|
old map[uint64]*SegmentSnapshot
|
||||||
oldNewDocNums map[uint64][]uint64
|
oldNewDocNums map[uint64][]uint64
|
||||||
new segment.Segment
|
new segment.Segment
|
||||||
notify notificationChan
|
notify chan *IndexSnapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
// perform a merging of the given SegmentBase instances into a new,
|
||||||
|
// persisted segment, and synchronously introduce that new segment
|
||||||
|
// into the root
|
||||||
|
func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot,
|
||||||
|
sbs []*zap.SegmentBase, sbsDrops []*roaring.Bitmap, sbsIndexes []int,
|
||||||
|
chunkFactor uint32) (uint64, *IndexSnapshot, uint64, error) {
|
||||||
|
var br bytes.Buffer
|
||||||
|
|
||||||
|
cr := zap.NewCountHashWriter(&br)
|
||||||
|
|
||||||
|
newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset,
|
||||||
|
docValueOffset, dictLocs, fieldsInv, fieldsMap, err :=
|
||||||
|
zap.MergeToWriter(sbs, sbsDrops, chunkFactor, cr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sb, err := zap.InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
|
||||||
|
fieldsMap, fieldsInv, numDocs, storedIndexOffset, fieldsIndexOffset,
|
||||||
|
docValueOffset, dictLocs)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
|
||||||
|
|
||||||
|
filename := zapFileName(newSegmentID)
|
||||||
|
path := s.path + string(os.PathSeparator) + filename
|
||||||
|
err = zap.PersistSegmentBase(sb, path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
segment, err := zap.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sm := &segmentMerge{
|
||||||
|
id: newSegmentID,
|
||||||
|
old: make(map[uint64]*SegmentSnapshot),
|
||||||
|
oldNewDocNums: make(map[uint64][]uint64),
|
||||||
|
new: segment,
|
||||||
|
notify: make(chan *IndexSnapshot, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, idx := range sbsIndexes {
|
||||||
|
ss := snapshot.segment[idx]
|
||||||
|
sm.old[ss.id] = ss
|
||||||
|
sm.oldNewDocNums[ss.id] = newDocNums[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
select { // send to introducer
|
||||||
|
case <-s.closeCh:
|
||||||
|
_ = segment.DecRef()
|
||||||
|
return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
|
||||||
|
case s.merges <- sm:
|
||||||
|
}
|
||||||
|
|
||||||
|
select { // wait for introduction to complete
|
||||||
|
case <-s.closeCh:
|
||||||
|
return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
|
||||||
|
case newSnapshot := <-sm.notify:
|
||||||
|
return numDocs, newSnapshot, newSegmentID, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -186,13 +186,13 @@ func plan(segmentsIn []Segment, o *MergePlanOptions) (*MergePlan, error) {
|
||||||
|
|
||||||
// While we’re over budget, keep looping, which might produce
|
// While we’re over budget, keep looping, which might produce
|
||||||
// another MergeTask.
|
// another MergeTask.
|
||||||
for len(eligibles) > budgetNumSegments {
|
for len(eligibles) > 0 && (len(eligibles)+len(rv.Tasks)) > budgetNumSegments {
|
||||||
// Track a current best roster as we examine and score
|
// Track a current best roster as we examine and score
|
||||||
// potential rosters of merges.
|
// potential rosters of merges.
|
||||||
var bestRoster []Segment
|
var bestRoster []Segment
|
||||||
var bestRosterScore float64 // Lower score is better.
|
var bestRosterScore float64 // Lower score is better.
|
||||||
|
|
||||||
for startIdx := 0; startIdx < len(eligibles)-o.SegmentsPerMergeTask; startIdx++ {
|
for startIdx := 0; startIdx < len(eligibles); startIdx++ {
|
||||||
var roster []Segment
|
var roster []Segment
|
||||||
var rosterLiveSize int64
|
var rosterLiveSize int64
|
||||||
|
|
||||||
|
|
|
@ -34,22 +34,39 @@ import (
|
||||||
|
|
||||||
var DefaultChunkFactor uint32 = 1024
|
var DefaultChunkFactor uint32 = 1024
|
||||||
|
|
||||||
|
// Arbitrary number, need to make it configurable.
|
||||||
|
// Lower values like 10/making persister really slow
|
||||||
|
// doesn't work well as it is creating more files to
|
||||||
|
// persist for in next persist iteration and spikes the # FDs.
|
||||||
|
// Ideal value should let persister also proceed at
|
||||||
|
// an optimum pace so that the merger can skip
|
||||||
|
// many intermediate snapshots.
|
||||||
|
// This needs to be based on empirical data.
|
||||||
|
// TODO - may need to revisit this approach/value.
|
||||||
|
var epochDistance = uint64(5)
|
||||||
|
|
||||||
type notificationChan chan struct{}
|
type notificationChan chan struct{}
|
||||||
|
|
||||||
func (s *Scorch) persisterLoop() {
|
func (s *Scorch) persisterLoop() {
|
||||||
defer s.asyncTasks.Done()
|
defer s.asyncTasks.Done()
|
||||||
|
|
||||||
var notifyChs []notificationChan
|
var persistWatchers []*epochWatcher
|
||||||
var lastPersistedEpoch uint64
|
var lastPersistedEpoch, lastMergedEpoch uint64
|
||||||
|
var ew *epochWatcher
|
||||||
OUTER:
|
OUTER:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-s.closeCh:
|
case <-s.closeCh:
|
||||||
break OUTER
|
break OUTER
|
||||||
case notifyCh := <-s.persisterNotifier:
|
case ew = <-s.persisterNotifier:
|
||||||
notifyChs = append(notifyChs, notifyCh)
|
persistWatchers = append(persistWatchers, ew)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
if ew != nil && ew.epoch > lastMergedEpoch {
|
||||||
|
lastMergedEpoch = ew.epoch
|
||||||
|
}
|
||||||
|
persistWatchers = s.pausePersisterForMergerCatchUp(lastPersistedEpoch,
|
||||||
|
&lastMergedEpoch, persistWatchers)
|
||||||
|
|
||||||
var ourSnapshot *IndexSnapshot
|
var ourSnapshot *IndexSnapshot
|
||||||
var ourPersisted []chan error
|
var ourPersisted []chan error
|
||||||
|
@ -81,10 +98,11 @@ OUTER:
|
||||||
}
|
}
|
||||||
|
|
||||||
lastPersistedEpoch = ourSnapshot.epoch
|
lastPersistedEpoch = ourSnapshot.epoch
|
||||||
for _, notifyCh := range notifyChs {
|
for _, ew := range persistWatchers {
|
||||||
close(notifyCh)
|
close(ew.notifyCh)
|
||||||
}
|
}
|
||||||
notifyChs = nil
|
|
||||||
|
persistWatchers = nil
|
||||||
_ = ourSnapshot.DecRef()
|
_ = ourSnapshot.DecRef()
|
||||||
|
|
||||||
changed := false
|
changed := false
|
||||||
|
@ -120,27 +138,155 @@ OUTER:
|
||||||
break OUTER
|
break OUTER
|
||||||
case <-w.notifyCh:
|
case <-w.notifyCh:
|
||||||
// woken up, next loop should pick up work
|
// woken up, next loop should pick up work
|
||||||
|
continue OUTER
|
||||||
|
case ew = <-s.persisterNotifier:
|
||||||
|
// if the watchers are already caught up then let them wait,
|
||||||
|
// else let them continue to do the catch up
|
||||||
|
persistWatchers = append(persistWatchers, ew)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func notifyMergeWatchers(lastPersistedEpoch uint64,
|
||||||
|
persistWatchers []*epochWatcher) []*epochWatcher {
|
||||||
|
var watchersNext []*epochWatcher
|
||||||
|
for _, w := range persistWatchers {
|
||||||
|
if w.epoch < lastPersistedEpoch {
|
||||||
|
close(w.notifyCh)
|
||||||
|
} else {
|
||||||
|
watchersNext = append(watchersNext, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return watchersNext
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scorch) pausePersisterForMergerCatchUp(lastPersistedEpoch uint64, lastMergedEpoch *uint64,
|
||||||
|
persistWatchers []*epochWatcher) []*epochWatcher {
|
||||||
|
|
||||||
|
// first, let the watchers proceed if they lag behind
|
||||||
|
persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
|
||||||
|
|
||||||
|
OUTER:
|
||||||
|
// check for slow merger and await until the merger catch up
|
||||||
|
for lastPersistedEpoch > *lastMergedEpoch+epochDistance {
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-s.closeCh:
|
||||||
|
break OUTER
|
||||||
|
case ew := <-s.persisterNotifier:
|
||||||
|
persistWatchers = append(persistWatchers, ew)
|
||||||
|
*lastMergedEpoch = ew.epoch
|
||||||
|
}
|
||||||
|
|
||||||
|
// let the watchers proceed if they lag behind
|
||||||
|
persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
|
||||||
|
}
|
||||||
|
|
||||||
|
return persistWatchers
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
||||||
|
persisted, err := s.persistSnapshotMaybeMerge(snapshot)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if persisted {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.persistSnapshotDirect(snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultMinSegmentsForInMemoryMerge represents the default number of
|
||||||
|
// in-memory zap segments that persistSnapshotMaybeMerge() needs to
|
||||||
|
// see in an IndexSnapshot before it decides to merge and persist
|
||||||
|
// those segments
|
||||||
|
var DefaultMinSegmentsForInMemoryMerge = 2
|
||||||
|
|
||||||
|
// persistSnapshotMaybeMerge examines the snapshot and might merge and
|
||||||
|
// persist the in-memory zap segments if there are enough of them
|
||||||
|
func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) (
|
||||||
|
bool, error) {
|
||||||
|
// collect the in-memory zap segments (SegmentBase instances)
|
||||||
|
var sbs []*zap.SegmentBase
|
||||||
|
var sbsDrops []*roaring.Bitmap
|
||||||
|
var sbsIndexes []int
|
||||||
|
|
||||||
|
for i, segmentSnapshot := range snapshot.segment {
|
||||||
|
if sb, ok := segmentSnapshot.segment.(*zap.SegmentBase); ok {
|
||||||
|
sbs = append(sbs, sb)
|
||||||
|
sbsDrops = append(sbsDrops, segmentSnapshot.deleted)
|
||||||
|
sbsIndexes = append(sbsIndexes, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sbs) < DefaultMinSegmentsForInMemoryMerge {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, newSnapshot, newSegmentID, err := s.mergeSegmentBases(
|
||||||
|
snapshot, sbs, sbsDrops, sbsIndexes, DefaultChunkFactor)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if newSnapshot == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = newSnapshot.DecRef()
|
||||||
|
}()
|
||||||
|
|
||||||
|
mergedSegmentIDs := map[uint64]struct{}{}
|
||||||
|
for _, idx := range sbsIndexes {
|
||||||
|
mergedSegmentIDs[snapshot.segment[idx].id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// construct a snapshot that's logically equivalent to the input
|
||||||
|
// snapshot, but with merged segments replaced by the new segment
|
||||||
|
equiv := &IndexSnapshot{
|
||||||
|
parent: snapshot.parent,
|
||||||
|
segment: make([]*SegmentSnapshot, 0, len(snapshot.segment)),
|
||||||
|
internal: snapshot.internal,
|
||||||
|
epoch: snapshot.epoch,
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy to the equiv the segments that weren't replaced
|
||||||
|
for _, segment := range snapshot.segment {
|
||||||
|
if _, wasMerged := mergedSegmentIDs[segment.id]; !wasMerged {
|
||||||
|
equiv.segment = append(equiv.segment, segment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// append to the equiv the new segment
|
||||||
|
for _, segment := range newSnapshot.segment {
|
||||||
|
if segment.id == newSegmentID {
|
||||||
|
equiv.segment = append(equiv.segment, &SegmentSnapshot{
|
||||||
|
id: newSegmentID,
|
||||||
|
segment: segment.segment,
|
||||||
|
deleted: nil, // nil since merging handled deletions
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.persistSnapshotDirect(equiv)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) {
|
||||||
// start a write transaction
|
// start a write transaction
|
||||||
tx, err := s.rootBolt.Begin(true)
|
tx, err := s.rootBolt.Begin(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// defer fsync of the rootbolt
|
// defer rollback on error
|
||||||
defer func() {
|
defer func() {
|
||||||
if err == nil {
|
if err != nil {
|
||||||
err = s.rootBolt.Sync()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
// defer commit/rollback transaction
|
|
||||||
defer func() {
|
|
||||||
if err == nil {
|
|
||||||
err = tx.Commit()
|
|
||||||
} else {
|
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -172,20 +318,20 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
||||||
newSegmentPaths := make(map[uint64]string)
|
newSegmentPaths := make(map[uint64]string)
|
||||||
|
|
||||||
// first ensure that each segment in this snapshot has been persisted
|
// first ensure that each segment in this snapshot has been persisted
|
||||||
for i, segmentSnapshot := range snapshot.segment {
|
for _, segmentSnapshot := range snapshot.segment {
|
||||||
snapshotSegmentKey := segment.EncodeUvarintAscending(nil, uint64(i))
|
snapshotSegmentKey := segment.EncodeUvarintAscending(nil, segmentSnapshot.id)
|
||||||
snapshotSegmentBucket, err2 := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
|
snapshotSegmentBucket, err := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
|
||||||
if err2 != nil {
|
if err != nil {
|
||||||
return err2
|
return err
|
||||||
}
|
}
|
||||||
switch seg := segmentSnapshot.segment.(type) {
|
switch seg := segmentSnapshot.segment.(type) {
|
||||||
case *zap.SegmentBase:
|
case *zap.SegmentBase:
|
||||||
// need to persist this to disk
|
// need to persist this to disk
|
||||||
filename := zapFileName(segmentSnapshot.id)
|
filename := zapFileName(segmentSnapshot.id)
|
||||||
path := s.path + string(os.PathSeparator) + filename
|
path := s.path + string(os.PathSeparator) + filename
|
||||||
err2 := zap.PersistSegmentBase(seg, path)
|
err = zap.PersistSegmentBase(seg, path)
|
||||||
if err2 != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error persisting segment: %v", err2)
|
return fmt.Errorf("error persisting segment: %v", err)
|
||||||
}
|
}
|
||||||
newSegmentPaths[segmentSnapshot.id] = path
|
newSegmentPaths[segmentSnapshot.id] = path
|
||||||
err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
|
err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
|
||||||
|
@ -218,19 +364,28 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only alter the root if we actually persisted a segment
|
// we need to swap in a new root only when we've persisted 1 or
|
||||||
// (sometimes its just a new snapshot, possibly with new internal values)
|
// more segments -- whereby the new root would have 1-for-1
|
||||||
|
// replacements of in-memory segments with file-based segments
|
||||||
|
//
|
||||||
|
// other cases like updates to internal values only, and/or when
|
||||||
|
// there are only deletions, are already covered and persisted by
|
||||||
|
// the newly populated boltdb snapshotBucket above
|
||||||
if len(newSegmentPaths) > 0 {
|
if len(newSegmentPaths) > 0 {
|
||||||
// now try to open all the new snapshots
|
// now try to open all the new snapshots
|
||||||
newSegments := make(map[uint64]segment.Segment)
|
newSegments := make(map[uint64]segment.Segment)
|
||||||
|
defer func() {
|
||||||
|
for _, s := range newSegments {
|
||||||
|
if s != nil {
|
||||||
|
// cleanup segments that were opened but not
|
||||||
|
// swapped into the new root
|
||||||
|
_ = s.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
for segmentID, path := range newSegmentPaths {
|
for segmentID, path := range newSegmentPaths {
|
||||||
newSegments[segmentID], err = zap.Open(path)
|
newSegments[segmentID], err = zap.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
for _, s := range newSegments {
|
|
||||||
if s != nil {
|
|
||||||
_ = s.Close() // cleanup segments that were successfully opened
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error opening new segment at %s, %v", path, err)
|
return fmt.Errorf("error opening new segment at %s, %v", path, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -255,6 +410,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
||||||
cachedDocs: segmentSnapshot.cachedDocs,
|
cachedDocs: segmentSnapshot.cachedDocs,
|
||||||
}
|
}
|
||||||
newIndexSnapshot.segment[i] = newSegmentSnapshot
|
newIndexSnapshot.segment[i] = newSegmentSnapshot
|
||||||
|
delete(newSegments, segmentSnapshot.id)
|
||||||
// update items persisted incase of a new segment snapshot
|
// update items persisted incase of a new segment snapshot
|
||||||
atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
|
atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
|
||||||
} else {
|
} else {
|
||||||
|
@ -266,9 +422,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
||||||
for k, v := range s.root.internal {
|
for k, v := range s.root.internal {
|
||||||
newIndexSnapshot.internal[k] = v
|
newIndexSnapshot.internal[k] = v
|
||||||
}
|
}
|
||||||
for _, filename := range filenames {
|
|
||||||
delete(s.ineligibleForRemoval, filename)
|
|
||||||
}
|
|
||||||
rootPrev := s.root
|
rootPrev := s.root
|
||||||
s.root = newIndexSnapshot
|
s.root = newIndexSnapshot
|
||||||
s.rootLock.Unlock()
|
s.rootLock.Unlock()
|
||||||
|
@ -277,6 +431,24 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.rootBolt.Sync()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow files to become eligible for removal after commit, such
|
||||||
|
// as file segments from snapshots that came from the merger
|
||||||
|
s.rootLock.Lock()
|
||||||
|
for _, filename := range filenames {
|
||||||
|
delete(s.ineligibleForRemoval, filename)
|
||||||
|
}
|
||||||
|
s.rootLock.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ type Scorch struct {
|
||||||
merges chan *segmentMerge
|
merges chan *segmentMerge
|
||||||
introducerNotifier chan *epochWatcher
|
introducerNotifier chan *epochWatcher
|
||||||
revertToSnapshots chan *snapshotReversion
|
revertToSnapshots chan *snapshotReversion
|
||||||
persisterNotifier chan notificationChan
|
persisterNotifier chan *epochWatcher
|
||||||
rootBolt *bolt.DB
|
rootBolt *bolt.DB
|
||||||
asyncTasks sync.WaitGroup
|
asyncTasks sync.WaitGroup
|
||||||
|
|
||||||
|
@ -114,6 +114,25 @@ func (s *Scorch) fireAsyncError(err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scorch) Open() error {
|
func (s *Scorch) Open() error {
|
||||||
|
err := s.openBolt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.asyncTasks.Add(1)
|
||||||
|
go s.mainLoop()
|
||||||
|
|
||||||
|
if !s.readOnly && s.path != "" {
|
||||||
|
s.asyncTasks.Add(1)
|
||||||
|
go s.persisterLoop()
|
||||||
|
s.asyncTasks.Add(1)
|
||||||
|
go s.mergerLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scorch) openBolt() error {
|
||||||
var ok bool
|
var ok bool
|
||||||
s.path, ok = s.config["path"].(string)
|
s.path, ok = s.config["path"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -136,6 +155,7 @@ func (s *Scorch) Open() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
|
rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
|
||||||
var err error
|
var err error
|
||||||
if s.path != "" {
|
if s.path != "" {
|
||||||
|
@ -156,7 +176,7 @@ func (s *Scorch) Open() error {
|
||||||
s.merges = make(chan *segmentMerge)
|
s.merges = make(chan *segmentMerge)
|
||||||
s.introducerNotifier = make(chan *epochWatcher, 1)
|
s.introducerNotifier = make(chan *epochWatcher, 1)
|
||||||
s.revertToSnapshots = make(chan *snapshotReversion)
|
s.revertToSnapshots = make(chan *snapshotReversion)
|
||||||
s.persisterNotifier = make(chan notificationChan)
|
s.persisterNotifier = make(chan *epochWatcher, 1)
|
||||||
|
|
||||||
if !s.readOnly && s.path != "" {
|
if !s.readOnly && s.path != "" {
|
||||||
err := s.removeOldZapFiles() // Before persister or merger create any new files.
|
err := s.removeOldZapFiles() // Before persister or merger create any new files.
|
||||||
|
@ -166,16 +186,6 @@ func (s *Scorch) Open() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.asyncTasks.Add(1)
|
|
||||||
go s.mainLoop()
|
|
||||||
|
|
||||||
if !s.readOnly && s.path != "" {
|
|
||||||
s.asyncTasks.Add(1)
|
|
||||||
go s.persisterLoop()
|
|
||||||
s.asyncTasks.Add(1)
|
|
||||||
go s.mergerLoop()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,17 +320,21 @@ func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
|
||||||
introduction.persisted = make(chan error, 1)
|
introduction.persisted = make(chan error, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get read lock, to optimistically prepare obsoleted info
|
// optimistically prepare obsoletes outside of rootLock
|
||||||
s.rootLock.RLock()
|
s.rootLock.RLock()
|
||||||
for _, seg := range s.root.segment {
|
root := s.root
|
||||||
|
root.AddRef()
|
||||||
|
s.rootLock.RUnlock()
|
||||||
|
|
||||||
|
for _, seg := range root.segment {
|
||||||
delta, err := seg.segment.DocNumbers(ids)
|
delta, err := seg.segment.DocNumbers(ids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.rootLock.RUnlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
introduction.obsoletes[seg.id] = delta
|
introduction.obsoletes[seg.id] = delta
|
||||||
}
|
}
|
||||||
s.rootLock.RUnlock()
|
|
||||||
|
_ = root.DecRef()
|
||||||
|
|
||||||
s.introductions <- introduction
|
s.introductions <- introduction
|
||||||
|
|
||||||
|
|
|
@ -95,6 +95,21 @@ func (s *Segment) initializeDict(results []*index.AnalysisResult) {
|
||||||
var numTokenFrequencies int
|
var numTokenFrequencies int
|
||||||
var totLocs int
|
var totLocs int
|
||||||
|
|
||||||
|
// initial scan for all fieldID's to sort them
|
||||||
|
for _, result := range results {
|
||||||
|
for _, field := range result.Document.CompositeFields {
|
||||||
|
s.getOrDefineField(field.Name())
|
||||||
|
}
|
||||||
|
for _, field := range result.Document.Fields {
|
||||||
|
s.getOrDefineField(field.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(s.FieldsInv[1:]) // keep _id as first field
|
||||||
|
s.FieldsMap = make(map[string]uint16, len(s.FieldsInv))
|
||||||
|
for fieldID, fieldName := range s.FieldsInv {
|
||||||
|
s.FieldsMap[fieldName] = uint16(fieldID + 1)
|
||||||
|
}
|
||||||
|
|
||||||
processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
|
processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
|
||||||
for term, tf := range tfs {
|
for term, tf := range tfs {
|
||||||
pidPlus1, exists := s.Dicts[fieldID][term]
|
pidPlus1, exists := s.Dicts[fieldID][term]
|
||||||
|
|
|
@ -76,6 +76,8 @@ type DictionaryIterator struct {
|
||||||
prefix string
|
prefix string
|
||||||
end string
|
end string
|
||||||
offset int
|
offset int
|
||||||
|
|
||||||
|
dictEntry index.DictEntry // reused across Next()'s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next returns the next entry in the dictionary
|
// Next returns the next entry in the dictionary
|
||||||
|
@ -95,8 +97,7 @@ func (d *DictionaryIterator) Next() (*index.DictEntry, error) {
|
||||||
|
|
||||||
d.offset++
|
d.offset++
|
||||||
postingID := d.d.segment.Dicts[d.d.fieldID][next]
|
postingID := d.d.segment.Dicts[d.d.fieldID][next]
|
||||||
return &index.DictEntry{
|
d.dictEntry.Term = next
|
||||||
Term: next,
|
d.dictEntry.Count = d.d.segment.Postings[postingID-1].GetCardinality()
|
||||||
Count: d.d.segment.Postings[postingID-1].GetCardinality(),
|
return &d.dictEntry, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ import (
|
||||||
"github.com/golang/snappy"
|
"github.com/golang/snappy"
|
||||||
)
|
)
|
||||||
|
|
||||||
const version uint32 = 2
|
const version uint32 = 3
|
||||||
|
|
||||||
const fieldNotUninverted = math.MaxUint64
|
const fieldNotUninverted = math.MaxUint64
|
||||||
|
|
||||||
|
@ -187,79 +187,42 @@ func persistBase(memSegment *mem.Segment, cr *CountHashWriter, chunkFactor uint3
|
||||||
}
|
}
|
||||||
|
|
||||||
func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) {
|
func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) {
|
||||||
|
|
||||||
var curr int
|
var curr int
|
||||||
var metaBuf bytes.Buffer
|
var metaBuf bytes.Buffer
|
||||||
var data, compressed []byte
|
var data, compressed []byte
|
||||||
|
|
||||||
|
metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
|
||||||
|
|
||||||
docNumOffsets := make(map[int]uint64, len(memSegment.Stored))
|
docNumOffsets := make(map[int]uint64, len(memSegment.Stored))
|
||||||
|
|
||||||
for docNum, storedValues := range memSegment.Stored {
|
for docNum, storedValues := range memSegment.Stored {
|
||||||
if docNum != 0 {
|
if docNum != 0 {
|
||||||
// reset buffer if necessary
|
// reset buffer if necessary
|
||||||
|
curr = 0
|
||||||
metaBuf.Reset()
|
metaBuf.Reset()
|
||||||
data = data[:0]
|
data = data[:0]
|
||||||
compressed = compressed[:0]
|
compressed = compressed[:0]
|
||||||
curr = 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
|
|
||||||
|
|
||||||
st := memSegment.StoredTypes[docNum]
|
st := memSegment.StoredTypes[docNum]
|
||||||
sp := memSegment.StoredPos[docNum]
|
sp := memSegment.StoredPos[docNum]
|
||||||
|
|
||||||
// encode fields in order
|
// encode fields in order
|
||||||
for fieldID := range memSegment.FieldsInv {
|
for fieldID := range memSegment.FieldsInv {
|
||||||
if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok {
|
if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok {
|
||||||
// has stored values for this field
|
|
||||||
num := len(storedFieldValues)
|
|
||||||
|
|
||||||
stf := st[uint16(fieldID)]
|
stf := st[uint16(fieldID)]
|
||||||
spf := sp[uint16(fieldID)]
|
spf := sp[uint16(fieldID)]
|
||||||
|
|
||||||
// process each value
|
var err2 error
|
||||||
for i := 0; i < num; i++ {
|
curr, data, err2 = persistStoredFieldValues(fieldID,
|
||||||
// encode field
|
storedFieldValues, stf, spf, curr, metaEncoder, data)
|
||||||
_, err2 := metaEncoder.PutU64(uint64(fieldID))
|
if err2 != nil {
|
||||||
if err2 != nil {
|
return 0, err2
|
||||||
return 0, err2
|
|
||||||
}
|
|
||||||
// encode type
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(stf[i]))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, err2
|
|
||||||
}
|
|
||||||
// encode start offset
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(curr))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, err2
|
|
||||||
}
|
|
||||||
// end len
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, err2
|
|
||||||
}
|
|
||||||
// encode number of array pos
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(len(spf[i])))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, err2
|
|
||||||
}
|
|
||||||
// encode all array positions
|
|
||||||
for _, pos := range spf[i] {
|
|
||||||
_, err2 = metaEncoder.PutU64(pos)
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, err2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// append data
|
|
||||||
data = append(data, storedFieldValues[i]...)
|
|
||||||
// update curr
|
|
||||||
curr += len(storedFieldValues[i])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
metaEncoder.Close()
|
|
||||||
|
|
||||||
|
metaEncoder.Close()
|
||||||
metaBytes := metaBuf.Bytes()
|
metaBytes := metaBuf.Bytes()
|
||||||
|
|
||||||
// compress the data
|
// compress the data
|
||||||
|
@ -299,6 +262,51 @@ func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error)
|
||||||
return rv, nil
|
return rv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func persistStoredFieldValues(fieldID int,
|
||||||
|
storedFieldValues [][]byte, stf []byte, spf [][]uint64,
|
||||||
|
curr int, metaEncoder *govarint.Base128Encoder, data []byte) (
|
||||||
|
int, []byte, error) {
|
||||||
|
for i := 0; i < len(storedFieldValues); i++ {
|
||||||
|
// encode field
|
||||||
|
_, err := metaEncoder.PutU64(uint64(fieldID))
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
// encode type
|
||||||
|
_, err = metaEncoder.PutU64(uint64(stf[i]))
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
// encode start offset
|
||||||
|
_, err = metaEncoder.PutU64(uint64(curr))
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
// end len
|
||||||
|
_, err = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
// encode number of array pos
|
||||||
|
_, err = metaEncoder.PutU64(uint64(len(spf[i])))
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
// encode all array positions
|
||||||
|
for _, pos := range spf[i] {
|
||||||
|
_, err = metaEncoder.PutU64(pos)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data = append(data, storedFieldValues[i]...)
|
||||||
|
curr += len(storedFieldValues[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return curr, data, nil
|
||||||
|
}
|
||||||
|
|
||||||
func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
|
func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
|
||||||
var freqOffsets, locOfffsets []uint64
|
var freqOffsets, locOfffsets []uint64
|
||||||
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))
|
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))
|
||||||
|
@ -580,7 +588,7 @@ func persistDocValues(memSegment *mem.Segment, w *CountHashWriter,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// resetting encoder for the next field
|
// reseting encoder for the next field
|
||||||
fdvEncoder.Reset()
|
fdvEncoder.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -625,12 +633,21 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
|
||||||
|
memSegment.FieldsMap, memSegment.FieldsInv, numDocs,
|
||||||
|
storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitSegmentBase(mem []byte, memCRC uint32, chunkFactor uint32,
|
||||||
|
fieldsMap map[string]uint16, fieldsInv []string, numDocs uint64,
|
||||||
|
storedIndexOffset uint64, fieldsIndexOffset uint64, docValueOffset uint64,
|
||||||
|
dictLocs []uint64) (*SegmentBase, error) {
|
||||||
sb := &SegmentBase{
|
sb := &SegmentBase{
|
||||||
mem: br.Bytes(),
|
mem: mem,
|
||||||
memCRC: cr.Sum32(),
|
memCRC: memCRC,
|
||||||
chunkFactor: chunkFactor,
|
chunkFactor: chunkFactor,
|
||||||
fieldsMap: memSegment.FieldsMap,
|
fieldsMap: fieldsMap,
|
||||||
fieldsInv: memSegment.FieldsInv,
|
fieldsInv: fieldsInv,
|
||||||
numDocs: numDocs,
|
numDocs: numDocs,
|
||||||
storedIndexOffset: storedIndexOffset,
|
storedIndexOffset: storedIndexOffset,
|
||||||
fieldsIndexOffset: fieldsIndexOffset,
|
fieldsIndexOffset: fieldsIndexOffset,
|
||||||
|
@ -639,7 +656,7 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
|
||||||
fieldDvIterMap: make(map[uint16]*docValueIterator),
|
fieldDvIterMap: make(map[uint16]*docValueIterator),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sb.loadDvIterators()
|
err := sb.loadDvIterators()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ type chunkedContentCoder struct {
|
||||||
// MetaData represents the data information inside a
|
// MetaData represents the data information inside a
|
||||||
// chunk.
|
// chunk.
|
||||||
type MetaData struct {
|
type MetaData struct {
|
||||||
DocID uint64 // docid of the data inside the chunk
|
DocNum uint64 // docNum of the data inside the chunk
|
||||||
DocDvLoc uint64 // starting offset for a given docid
|
DocDvLoc uint64 // starting offset for a given docid
|
||||||
DocDvLen uint64 // length of data inside the chunk for the given docid
|
DocDvLen uint64 // length of data inside the chunk for the given docid
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ func newChunkedContentCoder(chunkSize uint64,
|
||||||
rv := &chunkedContentCoder{
|
rv := &chunkedContentCoder{
|
||||||
chunkSize: chunkSize,
|
chunkSize: chunkSize,
|
||||||
chunkLens: make([]uint64, total),
|
chunkLens: make([]uint64, total),
|
||||||
chunkMeta: []MetaData{},
|
chunkMeta: make([]MetaData, 0, total),
|
||||||
}
|
}
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
|
@ -68,7 +68,7 @@ func (c *chunkedContentCoder) Reset() {
|
||||||
for i := range c.chunkLens {
|
for i := range c.chunkLens {
|
||||||
c.chunkLens[i] = 0
|
c.chunkLens[i] = 0
|
||||||
}
|
}
|
||||||
c.chunkMeta = []MetaData{}
|
c.chunkMeta = c.chunkMeta[:0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close indicates you are done calling Add() this allows
|
// Close indicates you are done calling Add() this allows
|
||||||
|
@ -88,7 +88,7 @@ func (c *chunkedContentCoder) flushContents() error {
|
||||||
|
|
||||||
// write out the metaData slice
|
// write out the metaData slice
|
||||||
for _, meta := range c.chunkMeta {
|
for _, meta := range c.chunkMeta {
|
||||||
_, err := writeUvarints(&c.chunkMetaBuf, meta.DocID, meta.DocDvLoc, meta.DocDvLen)
|
_, err := writeUvarints(&c.chunkMetaBuf, meta.DocNum, meta.DocDvLoc, meta.DocDvLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
|
||||||
// clearing the chunk specific meta for next chunk
|
// clearing the chunk specific meta for next chunk
|
||||||
c.chunkBuf.Reset()
|
c.chunkBuf.Reset()
|
||||||
c.chunkMetaBuf.Reset()
|
c.chunkMetaBuf.Reset()
|
||||||
c.chunkMeta = []MetaData{}
|
c.chunkMeta = c.chunkMeta[:0]
|
||||||
c.currChunk = chunk
|
c.currChunk = chunk
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.chunkMeta = append(c.chunkMeta, MetaData{
|
c.chunkMeta = append(c.chunkMeta, MetaData{
|
||||||
DocID: docNum,
|
DocNum: docNum,
|
||||||
DocDvLoc: uint64(dvOffset),
|
DocDvLoc: uint64(dvOffset),
|
||||||
DocDvLen: uint64(dvSize),
|
DocDvLen: uint64(dvSize),
|
||||||
})
|
})
|
||||||
|
|
|
@ -34,32 +34,47 @@ type Dictionary struct {
|
||||||
|
|
||||||
// PostingsList returns the postings list for the specified term
|
// PostingsList returns the postings list for the specified term
|
||||||
func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
|
func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
|
||||||
return d.postingsList([]byte(term), except)
|
return d.postingsList([]byte(term), except, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap) (*PostingsList, error) {
|
func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
|
||||||
rv := &PostingsList{
|
if d.fst == nil {
|
||||||
sb: d.sb,
|
return d.postingsListInit(rv, except), nil
|
||||||
term: term,
|
|
||||||
except: except,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.fst != nil {
|
postingsOffset, exists, err := d.fst.Get(term)
|
||||||
postingsOffset, exists, err := d.fst.Get(term)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("vellum err: %v", err)
|
||||||
return nil, fmt.Errorf("vellum err: %v", err)
|
}
|
||||||
}
|
if !exists {
|
||||||
if exists {
|
return d.postingsListInit(rv, except), nil
|
||||||
err = rv.read(postingsOffset, d)
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return d.postingsListFromOffset(postingsOffset, except, rv)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
func (d *Dictionary) postingsListFromOffset(postingsOffset uint64, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
|
||||||
|
rv = d.postingsListInit(rv, except)
|
||||||
|
|
||||||
|
err := rv.read(postingsOffset, d)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return rv, nil
|
return rv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Dictionary) postingsListInit(rv *PostingsList, except *roaring.Bitmap) *PostingsList {
|
||||||
|
if rv == nil {
|
||||||
|
rv = &PostingsList{}
|
||||||
|
} else {
|
||||||
|
*rv = PostingsList{} // clear the struct
|
||||||
|
}
|
||||||
|
rv.sb = d.sb
|
||||||
|
rv.except = except
|
||||||
|
return rv
|
||||||
|
}
|
||||||
|
|
||||||
// Iterator returns an iterator for this dictionary
|
// Iterator returns an iterator for this dictionary
|
||||||
func (d *Dictionary) Iterator() segment.DictionaryIterator {
|
func (d *Dictionary) Iterator() segment.DictionaryIterator {
|
||||||
rv := &DictionaryIterator{
|
rv := &DictionaryIterator{
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (s *SegmentBase) loadFieldDocValueIterator(field string,
|
||||||
func (di *docValueIterator) loadDvChunk(chunkNumber,
|
func (di *docValueIterator) loadDvChunk(chunkNumber,
|
||||||
localDocNum uint64, s *SegmentBase) error {
|
localDocNum uint64, s *SegmentBase) error {
|
||||||
// advance to the chunk where the docValues
|
// advance to the chunk where the docValues
|
||||||
// reside for the given docID
|
// reside for the given docNum
|
||||||
destChunkDataLoc := di.dvDataLoc
|
destChunkDataLoc := di.dvDataLoc
|
||||||
for i := 0; i < int(chunkNumber); i++ {
|
for i := 0; i < int(chunkNumber); i++ {
|
||||||
destChunkDataLoc += di.chunkLens[i]
|
destChunkDataLoc += di.chunkLens[i]
|
||||||
|
@ -116,7 +116,7 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
|
||||||
offset := uint64(0)
|
offset := uint64(0)
|
||||||
di.curChunkHeader = make([]MetaData, int(numDocs))
|
di.curChunkHeader = make([]MetaData, int(numDocs))
|
||||||
for i := 0; i < int(numDocs); i++ {
|
for i := 0; i < int(numDocs); i++ {
|
||||||
di.curChunkHeader[i].DocID, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
|
di.curChunkHeader[i].DocNum, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
|
||||||
offset += uint64(read)
|
offset += uint64(read)
|
||||||
di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
|
di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
|
||||||
offset += uint64(read)
|
offset += uint64(read)
|
||||||
|
@ -131,10 +131,10 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (di *docValueIterator) visitDocValues(docID uint64,
|
func (di *docValueIterator) visitDocValues(docNum uint64,
|
||||||
visitor index.DocumentFieldTermVisitor) error {
|
visitor index.DocumentFieldTermVisitor) error {
|
||||||
// binary search the term locations for the docID
|
// binary search the term locations for the docNum
|
||||||
start, length := di.getDocValueLocs(docID)
|
start, length := di.getDocValueLocs(docNum)
|
||||||
if start == math.MaxUint64 || length == math.MaxUint64 {
|
if start == math.MaxUint64 || length == math.MaxUint64 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func (di *docValueIterator) visitDocValues(docID uint64,
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// pick the terms for the given docID
|
// pick the terms for the given docNum
|
||||||
uncompressed = uncompressed[start : start+length]
|
uncompressed = uncompressed[start : start+length]
|
||||||
for {
|
for {
|
||||||
i := bytes.Index(uncompressed, termSeparatorSplitSlice)
|
i := bytes.Index(uncompressed, termSeparatorSplitSlice)
|
||||||
|
@ -159,11 +159,11 @@ func (di *docValueIterator) visitDocValues(docID uint64,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (di *docValueIterator) getDocValueLocs(docID uint64) (uint64, uint64) {
|
func (di *docValueIterator) getDocValueLocs(docNum uint64) (uint64, uint64) {
|
||||||
i := sort.Search(len(di.curChunkHeader), func(i int) bool {
|
i := sort.Search(len(di.curChunkHeader), func(i int) bool {
|
||||||
return di.curChunkHeader[i].DocID >= docID
|
return di.curChunkHeader[i].DocNum >= docNum
|
||||||
})
|
})
|
||||||
if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocID == docID {
|
if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocNum == docNum {
|
||||||
return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
|
return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
|
||||||
}
|
}
|
||||||
return math.MaxUint64, math.MaxUint64
|
return math.MaxUint64, math.MaxUint64
|
||||||
|
|
124
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/enumerator.go
generated
vendored
Normal file
124
vendor/github.com/blevesearch/bleve/index/scorch/segment/zap/enumerator.go
generated
vendored
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
// Copyright (c) 2018 Couchbase, Inc.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package zap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/couchbase/vellum"
|
||||||
|
)
|
||||||
|
|
||||||
|
// enumerator provides an ordered traversal of multiple vellum
|
||||||
|
// iterators. Like JOIN of iterators, the enumerator produces a
|
||||||
|
// sequence of (key, iteratorIndex, value) tuples, sorted by key ASC,
|
||||||
|
// then iteratorIndex ASC, where the same key might be seen or
|
||||||
|
// repeated across multiple child iterators.
|
||||||
|
type enumerator struct {
|
||||||
|
itrs []vellum.Iterator
|
||||||
|
currKs [][]byte
|
||||||
|
currVs []uint64
|
||||||
|
|
||||||
|
lowK []byte
|
||||||
|
lowIdxs []int
|
||||||
|
lowCurr int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newEnumerator returns a new enumerator over the vellum Iterators
|
||||||
|
func newEnumerator(itrs []vellum.Iterator) (*enumerator, error) {
|
||||||
|
rv := &enumerator{
|
||||||
|
itrs: itrs,
|
||||||
|
currKs: make([][]byte, len(itrs)),
|
||||||
|
currVs: make([]uint64, len(itrs)),
|
||||||
|
lowIdxs: make([]int, 0, len(itrs)),
|
||||||
|
}
|
||||||
|
for i, itr := range rv.itrs {
|
||||||
|
rv.currKs[i], rv.currVs[i] = itr.Current()
|
||||||
|
}
|
||||||
|
rv.updateMatches()
|
||||||
|
if rv.lowK == nil {
|
||||||
|
return rv, vellum.ErrIteratorDone
|
||||||
|
}
|
||||||
|
return rv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateMatches maintains the low key matches based on the currKs
|
||||||
|
func (m *enumerator) updateMatches() {
|
||||||
|
m.lowK = nil
|
||||||
|
m.lowIdxs = m.lowIdxs[:0]
|
||||||
|
m.lowCurr = 0
|
||||||
|
|
||||||
|
for i, key := range m.currKs {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cmp := bytes.Compare(key, m.lowK)
|
||||||
|
if cmp < 0 || m.lowK == nil {
|
||||||
|
// reached a new low
|
||||||
|
m.lowK = key
|
||||||
|
m.lowIdxs = m.lowIdxs[:0]
|
||||||
|
m.lowIdxs = append(m.lowIdxs, i)
|
||||||
|
} else if cmp == 0 {
|
||||||
|
m.lowIdxs = append(m.lowIdxs, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Current returns the enumerator's current key, iterator-index, and
|
||||||
|
// value. If the enumerator is not pointing at a valid value (because
|
||||||
|
// Next returned an error previously), Current will return nil,0,0.
|
||||||
|
func (m *enumerator) Current() ([]byte, int, uint64) {
|
||||||
|
var i int
|
||||||
|
var v uint64
|
||||||
|
if m.lowCurr < len(m.lowIdxs) {
|
||||||
|
i = m.lowIdxs[m.lowCurr]
|
||||||
|
v = m.currVs[i]
|
||||||
|
}
|
||||||
|
return m.lowK, i, v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next advances the enumerator to the next key/iterator/value result,
|
||||||
|
// else vellum.ErrIteratorDone is returned.
|
||||||
|
func (m *enumerator) Next() error {
|
||||||
|
m.lowCurr += 1
|
||||||
|
if m.lowCurr >= len(m.lowIdxs) {
|
||||||
|
// move all the current low iterators forwards
|
||||||
|
for _, vi := range m.lowIdxs {
|
||||||
|
err := m.itrs[vi].Next()
|
||||||
|
if err != nil && err != vellum.ErrIteratorDone {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.currKs[vi], m.currVs[vi] = m.itrs[vi].Current()
|
||||||
|
}
|
||||||
|
m.updateMatches()
|
||||||
|
}
|
||||||
|
if m.lowK == nil {
|
||||||
|
return vellum.ErrIteratorDone
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close all the underlying Iterators. The first error, if any, will
|
||||||
|
// be returned.
|
||||||
|
func (m *enumerator) Close() error {
|
||||||
|
var rv error
|
||||||
|
for _, itr := range m.itrs {
|
||||||
|
err := itr.Close()
|
||||||
|
if rv == nil {
|
||||||
|
rv = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rv
|
||||||
|
}
|
|
@ -30,6 +30,8 @@ type chunkedIntCoder struct {
|
||||||
encoder *govarint.Base128Encoder
|
encoder *govarint.Base128Encoder
|
||||||
chunkLens []uint64
|
chunkLens []uint64
|
||||||
currChunk uint64
|
currChunk uint64
|
||||||
|
|
||||||
|
buf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// newChunkedIntCoder returns a new chunk int coder which packs data into
|
// newChunkedIntCoder returns a new chunk int coder which packs data into
|
||||||
|
@ -67,12 +69,8 @@ func (c *chunkedIntCoder) Add(docNum uint64, vals ...uint64) error {
|
||||||
// starting a new chunk
|
// starting a new chunk
|
||||||
if c.encoder != nil {
|
if c.encoder != nil {
|
||||||
// close out last
|
// close out last
|
||||||
c.encoder.Close()
|
c.Close()
|
||||||
encodingBytes := c.chunkBuf.Bytes()
|
|
||||||
c.chunkLens[c.currChunk] = uint64(len(encodingBytes))
|
|
||||||
c.final = append(c.final, encodingBytes...)
|
|
||||||
c.chunkBuf.Reset()
|
c.chunkBuf.Reset()
|
||||||
c.encoder = govarint.NewU64Base128Encoder(&c.chunkBuf)
|
|
||||||
}
|
}
|
||||||
c.currChunk = chunk
|
c.currChunk = chunk
|
||||||
}
|
}
|
||||||
|
@ -98,26 +96,25 @@ func (c *chunkedIntCoder) Close() {
|
||||||
|
|
||||||
// Write commits all the encoded chunked integers to the provided writer.
|
// Write commits all the encoded chunked integers to the provided writer.
|
||||||
func (c *chunkedIntCoder) Write(w io.Writer) (int, error) {
|
func (c *chunkedIntCoder) Write(w io.Writer) (int, error) {
|
||||||
var tw int
|
bufNeeded := binary.MaxVarintLen64 * (1 + len(c.chunkLens))
|
||||||
buf := make([]byte, binary.MaxVarintLen64)
|
if len(c.buf) < bufNeeded {
|
||||||
// write out the number of chunks
|
c.buf = make([]byte, bufNeeded)
|
||||||
|
}
|
||||||
|
buf := c.buf
|
||||||
|
|
||||||
|
// write out the number of chunks & each chunkLen
|
||||||
n := binary.PutUvarint(buf, uint64(len(c.chunkLens)))
|
n := binary.PutUvarint(buf, uint64(len(c.chunkLens)))
|
||||||
nw, err := w.Write(buf[:n])
|
for _, chunkLen := range c.chunkLens {
|
||||||
tw += nw
|
n += binary.PutUvarint(buf[n:], uint64(chunkLen))
|
||||||
|
}
|
||||||
|
|
||||||
|
tw, err := w.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tw, err
|
return tw, err
|
||||||
}
|
}
|
||||||
// write out the chunk lens
|
|
||||||
for _, chunkLen := range c.chunkLens {
|
|
||||||
n := binary.PutUvarint(buf, uint64(chunkLen))
|
|
||||||
nw, err = w.Write(buf[:n])
|
|
||||||
tw += nw
|
|
||||||
if err != nil {
|
|
||||||
return tw, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// write out the data
|
// write out the data
|
||||||
nw, err = w.Write(c.final)
|
nw, err := w.Write(c.final)
|
||||||
tw += nw
|
tw += nw
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tw, err
|
return tw, err
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/RoaringBitmap/roaring"
|
"github.com/RoaringBitmap/roaring"
|
||||||
"github.com/Smerity/govarint"
|
"github.com/Smerity/govarint"
|
||||||
|
@ -28,6 +29,8 @@ import (
|
||||||
"github.com/golang/snappy"
|
"github.com/golang/snappy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const docDropped = math.MaxUint64 // sentinel docNum to represent a deleted doc
|
||||||
|
|
||||||
// Merge takes a slice of zap segments and bit masks describing which
|
// Merge takes a slice of zap segments and bit masks describing which
|
||||||
// documents may be dropped, and creates a new segment containing the
|
// documents may be dropped, and creates a new segment containing the
|
||||||
// remaining data. This new segment is built at the specified path,
|
// remaining data. This new segment is built at the specified path,
|
||||||
|
@ -46,47 +49,26 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
|
||||||
_ = os.Remove(path)
|
_ = os.Remove(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
segmentBases := make([]*SegmentBase, len(segments))
|
||||||
|
for segmenti, segment := range segments {
|
||||||
|
segmentBases[segmenti] = &segment.SegmentBase
|
||||||
|
}
|
||||||
|
|
||||||
// buffer the output
|
// buffer the output
|
||||||
br := bufio.NewWriter(f)
|
br := bufio.NewWriter(f)
|
||||||
|
|
||||||
// wrap it for counting (tracking offsets)
|
// wrap it for counting (tracking offsets)
|
||||||
cr := NewCountHashWriter(br)
|
cr := NewCountHashWriter(br)
|
||||||
|
|
||||||
fieldsInv := mergeFields(segments)
|
newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, _, _, _, err :=
|
||||||
fieldsMap := mapFields(fieldsInv)
|
MergeToWriter(segmentBases, drops, chunkFactor, cr)
|
||||||
|
|
||||||
var newDocNums [][]uint64
|
|
||||||
var storedIndexOffset uint64
|
|
||||||
fieldDvLocsOffset := uint64(fieldNotUninverted)
|
|
||||||
var dictLocs []uint64
|
|
||||||
|
|
||||||
newSegDocCount := computeNewDocCount(segments, drops)
|
|
||||||
if newSegDocCount > 0 {
|
|
||||||
storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
|
|
||||||
fieldsMap, fieldsInv, newSegDocCount, cr)
|
|
||||||
if err != nil {
|
|
||||||
cleanup()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dictLocs, fieldDvLocsOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
|
|
||||||
newDocNums, newSegDocCount, chunkFactor, cr)
|
|
||||||
if err != nil {
|
|
||||||
cleanup()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
dictLocs = make([]uint64, len(fieldsInv))
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldsIndexOffset, err := persistFields(fieldsInv, cr, dictLocs)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanup()
|
cleanup()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = persistFooter(newSegDocCount, storedIndexOffset,
|
err = persistFooter(numDocs, storedIndexOffset, fieldsIndexOffset,
|
||||||
fieldsIndexOffset, fieldDvLocsOffset, chunkFactor, cr.Sum32(), cr)
|
docValueOffset, chunkFactor, cr.Sum32(), cr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanup()
|
cleanup()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -113,21 +95,59 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
|
||||||
return newDocNums, nil
|
return newDocNums, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapFields takes the fieldsInv list and builds the map
|
func MergeToWriter(segments []*SegmentBase, drops []*roaring.Bitmap,
|
||||||
|
chunkFactor uint32, cr *CountHashWriter) (
|
||||||
|
newDocNums [][]uint64,
|
||||||
|
numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset uint64,
|
||||||
|
dictLocs []uint64, fieldsInv []string, fieldsMap map[string]uint16,
|
||||||
|
err error) {
|
||||||
|
docValueOffset = uint64(fieldNotUninverted)
|
||||||
|
|
||||||
|
var fieldsSame bool
|
||||||
|
fieldsSame, fieldsInv = mergeFields(segments)
|
||||||
|
fieldsMap = mapFields(fieldsInv)
|
||||||
|
|
||||||
|
numDocs = computeNewDocCount(segments, drops)
|
||||||
|
if numDocs > 0 {
|
||||||
|
storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
|
||||||
|
fieldsMap, fieldsInv, fieldsSame, numDocs, cr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, 0, 0, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dictLocs, docValueOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
|
||||||
|
newDocNums, numDocs, chunkFactor, cr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, 0, 0, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dictLocs = make([]uint64, len(fieldsInv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldsIndexOffset, err = persistFields(fieldsInv, cr, dictLocs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, 0, 0, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs, fieldsInv, fieldsMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapFields takes the fieldsInv list and returns a map of fieldName
|
||||||
|
// to fieldID+1
|
||||||
func mapFields(fields []string) map[string]uint16 {
|
func mapFields(fields []string) map[string]uint16 {
|
||||||
rv := make(map[string]uint16, len(fields))
|
rv := make(map[string]uint16, len(fields))
|
||||||
for i, fieldName := range fields {
|
for i, fieldName := range fields {
|
||||||
rv[fieldName] = uint16(i)
|
rv[fieldName] = uint16(i) + 1
|
||||||
}
|
}
|
||||||
return rv
|
return rv
|
||||||
}
|
}
|
||||||
|
|
||||||
// computeNewDocCount determines how many documents will be in the newly
|
// computeNewDocCount determines how many documents will be in the newly
|
||||||
// merged segment when obsoleted docs are dropped
|
// merged segment when obsoleted docs are dropped
|
||||||
func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 {
|
func computeNewDocCount(segments []*SegmentBase, drops []*roaring.Bitmap) uint64 {
|
||||||
var newDocCount uint64
|
var newDocCount uint64
|
||||||
for segI, segment := range segments {
|
for segI, segment := range segments {
|
||||||
newDocCount += segment.NumDocs()
|
newDocCount += segment.numDocs
|
||||||
if drops[segI] != nil {
|
if drops[segI] != nil {
|
||||||
newDocCount -= drops[segI].GetCardinality()
|
newDocCount -= drops[segI].GetCardinality()
|
||||||
}
|
}
|
||||||
|
@ -135,8 +155,8 @@ func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 {
|
||||||
return newDocCount
|
return newDocCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
func persistMergedRest(segments []*SegmentBase, dropsIn []*roaring.Bitmap,
|
||||||
fieldsInv []string, fieldsMap map[string]uint16, newDocNums [][]uint64,
|
fieldsInv []string, fieldsMap map[string]uint16, newDocNumsIn [][]uint64,
|
||||||
newSegDocCount uint64, chunkFactor uint32,
|
newSegDocCount uint64, chunkFactor uint32,
|
||||||
w *CountHashWriter) ([]uint64, uint64, error) {
|
w *CountHashWriter) ([]uint64, uint64, error) {
|
||||||
|
|
||||||
|
@ -144,9 +164,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64)
|
var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64)
|
||||||
var bufLoc []uint64
|
var bufLoc []uint64
|
||||||
|
|
||||||
|
var postings *PostingsList
|
||||||
|
var postItr *PostingsIterator
|
||||||
|
|
||||||
rv := make([]uint64, len(fieldsInv))
|
rv := make([]uint64, len(fieldsInv))
|
||||||
fieldDvLocs := make([]uint64, len(fieldsInv))
|
fieldDvLocs := make([]uint64, len(fieldsInv))
|
||||||
fieldDvLocsOffset := uint64(fieldNotUninverted)
|
|
||||||
|
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
|
||||||
|
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
|
||||||
|
|
||||||
// docTermMap is keyed by docNum, where the array impl provides
|
// docTermMap is keyed by docNum, where the array impl provides
|
||||||
// better memory usage behavior than a sparse-friendlier hashmap
|
// better memory usage behavior than a sparse-friendlier hashmap
|
||||||
|
@ -166,36 +191,31 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// collect FST iterators from all segments for this field
|
// collect FST iterators from all active segments for this field
|
||||||
|
var newDocNums [][]uint64
|
||||||
|
var drops []*roaring.Bitmap
|
||||||
var dicts []*Dictionary
|
var dicts []*Dictionary
|
||||||
var itrs []vellum.Iterator
|
var itrs []vellum.Iterator
|
||||||
for _, segment := range segments {
|
|
||||||
|
for segmentI, segment := range segments {
|
||||||
dict, err2 := segment.dictionary(fieldName)
|
dict, err2 := segment.dictionary(fieldName)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
return nil, 0, err2
|
return nil, 0, err2
|
||||||
}
|
}
|
||||||
dicts = append(dicts, dict)
|
|
||||||
|
|
||||||
if dict != nil && dict.fst != nil {
|
if dict != nil && dict.fst != nil {
|
||||||
itr, err2 := dict.fst.Iterator(nil, nil)
|
itr, err2 := dict.fst.Iterator(nil, nil)
|
||||||
if err2 != nil && err2 != vellum.ErrIteratorDone {
|
if err2 != nil && err2 != vellum.ErrIteratorDone {
|
||||||
return nil, 0, err2
|
return nil, 0, err2
|
||||||
}
|
}
|
||||||
if itr != nil {
|
if itr != nil {
|
||||||
|
newDocNums = append(newDocNums, newDocNumsIn[segmentI])
|
||||||
|
drops = append(drops, dropsIn[segmentI])
|
||||||
|
dicts = append(dicts, dict)
|
||||||
itrs = append(itrs, itr)
|
itrs = append(itrs, itr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// create merging iterator
|
|
||||||
mergeItr, err := vellum.NewMergeIterator(itrs, func(postingOffsets []uint64) uint64 {
|
|
||||||
// we don't actually use the merged value
|
|
||||||
return 0
|
|
||||||
})
|
|
||||||
|
|
||||||
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
|
|
||||||
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
|
|
||||||
|
|
||||||
if uint64(cap(docTermMap)) < newSegDocCount {
|
if uint64(cap(docTermMap)) < newSegDocCount {
|
||||||
docTermMap = make([][]byte, newSegDocCount)
|
docTermMap = make([][]byte, newSegDocCount)
|
||||||
} else {
|
} else {
|
||||||
|
@ -205,70 +225,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for err == nil {
|
var prevTerm []byte
|
||||||
term, _ := mergeItr.Current()
|
|
||||||
|
|
||||||
newRoaring := roaring.NewBitmap()
|
newRoaring := roaring.NewBitmap()
|
||||||
newRoaringLocs := roaring.NewBitmap()
|
newRoaringLocs := roaring.NewBitmap()
|
||||||
|
|
||||||
tfEncoder.Reset()
|
finishTerm := func(term []byte) error {
|
||||||
locEncoder.Reset()
|
if term == nil {
|
||||||
|
return nil
|
||||||
// now go back and get posting list for this term
|
|
||||||
// but pass in the deleted docs for that segment
|
|
||||||
for dictI, dict := range dicts {
|
|
||||||
if dict == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
postings, err2 := dict.postingsList(term, drops[dictI])
|
|
||||||
if err2 != nil {
|
|
||||||
return nil, 0, err2
|
|
||||||
}
|
|
||||||
|
|
||||||
postItr := postings.Iterator()
|
|
||||||
next, err2 := postItr.Next()
|
|
||||||
for next != nil && err2 == nil {
|
|
||||||
hitNewDocNum := newDocNums[dictI][next.Number()]
|
|
||||||
if hitNewDocNum == docDropped {
|
|
||||||
return nil, 0, fmt.Errorf("see hit with dropped doc num")
|
|
||||||
}
|
|
||||||
newRoaring.Add(uint32(hitNewDocNum))
|
|
||||||
// encode norm bits
|
|
||||||
norm := next.Norm()
|
|
||||||
normBits := math.Float32bits(float32(norm))
|
|
||||||
err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits))
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
locs := next.Locations()
|
|
||||||
if len(locs) > 0 {
|
|
||||||
newRoaringLocs.Add(uint32(hitNewDocNum))
|
|
||||||
for _, loc := range locs {
|
|
||||||
if cap(bufLoc) < 5+len(loc.ArrayPositions()) {
|
|
||||||
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
|
|
||||||
}
|
|
||||||
args := bufLoc[0:5]
|
|
||||||
args[0] = uint64(fieldsMap[loc.Field()])
|
|
||||||
args[1] = loc.Pos()
|
|
||||||
args[2] = loc.Start()
|
|
||||||
args[3] = loc.End()
|
|
||||||
args[4] = uint64(len(loc.ArrayPositions()))
|
|
||||||
args = append(args, loc.ArrayPositions()...)
|
|
||||||
err = locEncoder.Add(hitNewDocNum, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
docTermMap[hitNewDocNum] =
|
|
||||||
append(append(docTermMap[hitNewDocNum], term...), termSeparator)
|
|
||||||
|
|
||||||
next, err2 = postItr.Next()
|
|
||||||
}
|
|
||||||
if err2 != nil {
|
|
||||||
return nil, 0, err2
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tfEncoder.Close()
|
tfEncoder.Close()
|
||||||
|
@ -277,59 +241,142 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
if newRoaring.GetCardinality() > 0 {
|
if newRoaring.GetCardinality() > 0 {
|
||||||
// this field/term actually has hits in the new segment, lets write it down
|
// this field/term actually has hits in the new segment, lets write it down
|
||||||
freqOffset := uint64(w.Count())
|
freqOffset := uint64(w.Count())
|
||||||
_, err = tfEncoder.Write(w)
|
_, err := tfEncoder.Write(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
locOffset := uint64(w.Count())
|
locOffset := uint64(w.Count())
|
||||||
_, err = locEncoder.Write(w)
|
_, err = locEncoder.Write(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
postingLocOffset := uint64(w.Count())
|
postingLocOffset := uint64(w.Count())
|
||||||
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64)
|
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
postingOffset := uint64(w.Count())
|
postingOffset := uint64(w.Count())
|
||||||
|
|
||||||
// write out the start of the term info
|
// write out the start of the term info
|
||||||
buf := bufMaxVarintLen64
|
n := binary.PutUvarint(bufMaxVarintLen64, freqOffset)
|
||||||
n := binary.PutUvarint(buf, freqOffset)
|
_, err = w.Write(bufMaxVarintLen64[:n])
|
||||||
_, err = w.Write(buf[:n])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// write out the start of the loc info
|
// write out the start of the loc info
|
||||||
n = binary.PutUvarint(buf, locOffset)
|
n = binary.PutUvarint(bufMaxVarintLen64, locOffset)
|
||||||
_, err = w.Write(buf[:n])
|
_, err = w.Write(bufMaxVarintLen64[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
|
// write out the start of the posting locs
|
||||||
// write out the start of the loc posting list
|
n = binary.PutUvarint(bufMaxVarintLen64, postingLocOffset)
|
||||||
n = binary.PutUvarint(buf, postingLocOffset)
|
_, err = w.Write(bufMaxVarintLen64[:n])
|
||||||
_, err = w.Write(buf[:n])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64)
|
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newVellum.Insert(term, postingOffset)
|
err = newVellum.Insert(term, postingOffset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = mergeItr.Next()
|
newRoaring = roaring.NewBitmap()
|
||||||
|
newRoaringLocs = roaring.NewBitmap()
|
||||||
|
|
||||||
|
tfEncoder.Reset()
|
||||||
|
locEncoder.Reset()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
enumerator, err := newEnumerator(itrs)
|
||||||
|
|
||||||
|
for err == nil {
|
||||||
|
term, itrI, postingsOffset := enumerator.Current()
|
||||||
|
|
||||||
|
if !bytes.Equal(prevTerm, term) {
|
||||||
|
// if the term changed, write out the info collected
|
||||||
|
// for the previous term
|
||||||
|
err2 := finishTerm(prevTerm)
|
||||||
|
if err2 != nil {
|
||||||
|
return nil, 0, err2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var err2 error
|
||||||
|
postings, err2 = dicts[itrI].postingsListFromOffset(
|
||||||
|
postingsOffset, drops[itrI], postings)
|
||||||
|
if err2 != nil {
|
||||||
|
return nil, 0, err2
|
||||||
|
}
|
||||||
|
|
||||||
|
newDocNumsI := newDocNums[itrI]
|
||||||
|
|
||||||
|
postItr = postings.iterator(postItr)
|
||||||
|
next, err2 := postItr.Next()
|
||||||
|
for next != nil && err2 == nil {
|
||||||
|
hitNewDocNum := newDocNumsI[next.Number()]
|
||||||
|
if hitNewDocNum == docDropped {
|
||||||
|
return nil, 0, fmt.Errorf("see hit with dropped doc num")
|
||||||
|
}
|
||||||
|
newRoaring.Add(uint32(hitNewDocNum))
|
||||||
|
// encode norm bits
|
||||||
|
norm := next.Norm()
|
||||||
|
normBits := math.Float32bits(float32(norm))
|
||||||
|
err = tfEncoder.Add(hitNewDocNum, next.Frequency(), uint64(normBits))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
locs := next.Locations()
|
||||||
|
if len(locs) > 0 {
|
||||||
|
newRoaringLocs.Add(uint32(hitNewDocNum))
|
||||||
|
for _, loc := range locs {
|
||||||
|
if cap(bufLoc) < 5+len(loc.ArrayPositions()) {
|
||||||
|
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
|
||||||
|
}
|
||||||
|
args := bufLoc[0:5]
|
||||||
|
args[0] = uint64(fieldsMap[loc.Field()] - 1)
|
||||||
|
args[1] = loc.Pos()
|
||||||
|
args[2] = loc.Start()
|
||||||
|
args[3] = loc.End()
|
||||||
|
args[4] = uint64(len(loc.ArrayPositions()))
|
||||||
|
args = append(args, loc.ArrayPositions()...)
|
||||||
|
err = locEncoder.Add(hitNewDocNum, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
docTermMap[hitNewDocNum] =
|
||||||
|
append(append(docTermMap[hitNewDocNum], term...), termSeparator)
|
||||||
|
|
||||||
|
next, err2 = postItr.Next()
|
||||||
|
}
|
||||||
|
if err2 != nil {
|
||||||
|
return nil, 0, err2
|
||||||
|
}
|
||||||
|
|
||||||
|
prevTerm = prevTerm[:0] // copy to prevTerm in case Next() reuses term mem
|
||||||
|
prevTerm = append(prevTerm, term...)
|
||||||
|
|
||||||
|
err = enumerator.Next()
|
||||||
}
|
}
|
||||||
if err != nil && err != vellum.ErrIteratorDone {
|
if err != nil && err != vellum.ErrIteratorDone {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = finishTerm(prevTerm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
dictOffset := uint64(w.Count())
|
dictOffset := uint64(w.Count())
|
||||||
|
|
||||||
err = newVellum.Close()
|
err = newVellum.Close()
|
||||||
|
@ -378,7 +425,7 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldDvLocsOffset = uint64(w.Count())
|
fieldDvLocsOffset := uint64(w.Count())
|
||||||
|
|
||||||
buf := bufMaxVarintLen64
|
buf := bufMaxVarintLen64
|
||||||
for _, offset := range fieldDvLocs {
|
for _, offset := range fieldDvLocs {
|
||||||
|
@ -392,10 +439,8 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
return rv, fieldDvLocsOffset, nil
|
return rv, fieldDvLocsOffset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const docDropped = math.MaxUint64
|
func mergeStoredAndRemap(segments []*SegmentBase, drops []*roaring.Bitmap,
|
||||||
|
fieldsMap map[string]uint16, fieldsInv []string, fieldsSame bool, newSegDocCount uint64,
|
||||||
func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
|
|
||||||
fieldsMap map[string]uint16, fieldsInv []string, newSegDocCount uint64,
|
|
||||||
w *CountHashWriter) (uint64, [][]uint64, error) {
|
w *CountHashWriter) (uint64, [][]uint64, error) {
|
||||||
var rv [][]uint64 // The remapped or newDocNums for each segment.
|
var rv [][]uint64 // The remapped or newDocNums for each segment.
|
||||||
|
|
||||||
|
@ -417,10 +462,30 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
for segI, segment := range segments {
|
for segI, segment := range segments {
|
||||||
segNewDocNums := make([]uint64, segment.numDocs)
|
segNewDocNums := make([]uint64, segment.numDocs)
|
||||||
|
|
||||||
|
dropsI := drops[segI]
|
||||||
|
|
||||||
|
// optimize when the field mapping is the same across all
|
||||||
|
// segments and there are no deletions, via byte-copying
|
||||||
|
// of stored docs bytes directly to the writer
|
||||||
|
if fieldsSame && (dropsI == nil || dropsI.GetCardinality() == 0) {
|
||||||
|
err := segment.copyStoredDocs(newDocNum, docNumOffsets, w)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := uint64(0); i < segment.numDocs; i++ {
|
||||||
|
segNewDocNums[i] = newDocNum
|
||||||
|
newDocNum++
|
||||||
|
}
|
||||||
|
rv = append(rv, segNewDocNums)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// for each doc num
|
// for each doc num
|
||||||
for docNum := uint64(0); docNum < segment.numDocs; docNum++ {
|
for docNum := uint64(0); docNum < segment.numDocs; docNum++ {
|
||||||
// TODO: roaring's API limits docNums to 32-bits?
|
// TODO: roaring's API limits docNums to 32-bits?
|
||||||
if drops[segI] != nil && drops[segI].Contains(uint32(docNum)) {
|
if dropsI != nil && dropsI.Contains(uint32(docNum)) {
|
||||||
segNewDocNums[docNum] = docDropped
|
segNewDocNums[docNum] = docDropped
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -439,7 +504,7 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
poss[i] = poss[i][:0]
|
poss[i] = poss[i][:0]
|
||||||
}
|
}
|
||||||
err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool {
|
err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool {
|
||||||
fieldID := int(fieldsMap[field])
|
fieldID := int(fieldsMap[field]) - 1
|
||||||
vals[fieldID] = append(vals[fieldID], value)
|
vals[fieldID] = append(vals[fieldID], value)
|
||||||
typs[fieldID] = append(typs[fieldID], typ)
|
typs[fieldID] = append(typs[fieldID], typ)
|
||||||
poss[fieldID] = append(poss[fieldID], pos)
|
poss[fieldID] = append(poss[fieldID], pos)
|
||||||
|
@ -453,47 +518,14 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
for fieldID := range fieldsInv {
|
for fieldID := range fieldsInv {
|
||||||
storedFieldValues := vals[int(fieldID)]
|
storedFieldValues := vals[int(fieldID)]
|
||||||
|
|
||||||
// has stored values for this field
|
stf := typs[int(fieldID)]
|
||||||
num := len(storedFieldValues)
|
spf := poss[int(fieldID)]
|
||||||
|
|
||||||
// process each value
|
var err2 error
|
||||||
for i := 0; i < num; i++ {
|
curr, data, err2 = persistStoredFieldValues(fieldID,
|
||||||
// encode field
|
storedFieldValues, stf, spf, curr, metaEncoder, data)
|
||||||
_, err2 := metaEncoder.PutU64(uint64(fieldID))
|
if err2 != nil {
|
||||||
if err2 != nil {
|
return 0, nil, err2
|
||||||
return 0, nil, err2
|
|
||||||
}
|
|
||||||
// encode type
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(typs[int(fieldID)][i]))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, nil, err2
|
|
||||||
}
|
|
||||||
// encode start offset
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(curr))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, nil, err2
|
|
||||||
}
|
|
||||||
// end len
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, nil, err2
|
|
||||||
}
|
|
||||||
// encode number of array pos
|
|
||||||
_, err2 = metaEncoder.PutU64(uint64(len(poss[int(fieldID)][i])))
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, nil, err2
|
|
||||||
}
|
|
||||||
// encode all array positions
|
|
||||||
for j := 0; j < len(poss[int(fieldID)][i]); j++ {
|
|
||||||
_, err2 = metaEncoder.PutU64(poss[int(fieldID)][i][j])
|
|
||||||
if err2 != nil {
|
|
||||||
return 0, nil, err2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// append data
|
|
||||||
data = append(data, storedFieldValues[i]...)
|
|
||||||
// update curr
|
|
||||||
curr += len(storedFieldValues[i])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -528,36 +560,87 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
|
||||||
}
|
}
|
||||||
|
|
||||||
// return value is the start of the stored index
|
// return value is the start of the stored index
|
||||||
offset := uint64(w.Count())
|
storedIndexOffset := uint64(w.Count())
|
||||||
|
|
||||||
// now write out the stored doc index
|
// now write out the stored doc index
|
||||||
for docNum := range docNumOffsets {
|
for _, docNumOffset := range docNumOffsets {
|
||||||
err := binary.Write(w, binary.BigEndian, docNumOffsets[docNum])
|
err := binary.Write(w, binary.BigEndian, docNumOffset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return offset, rv, nil
|
return storedIndexOffset, rv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeFields builds a unified list of fields used across all the input segments
|
// copyStoredDocs writes out a segment's stored doc info, optimized by
|
||||||
func mergeFields(segments []*Segment) []string {
|
// using a single Write() call for the entire set of bytes. The
|
||||||
fieldsMap := map[string]struct{}{}
|
// newDocNumOffsets is filled with the new offsets for each doc.
|
||||||
|
func (s *SegmentBase) copyStoredDocs(newDocNum uint64, newDocNumOffsets []uint64,
|
||||||
|
w *CountHashWriter) error {
|
||||||
|
if s.numDocs <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
indexOffset0, storedOffset0, _, _, _ :=
|
||||||
|
s.getDocStoredOffsets(0) // the segment's first doc
|
||||||
|
|
||||||
|
indexOffsetN, storedOffsetN, readN, metaLenN, dataLenN :=
|
||||||
|
s.getDocStoredOffsets(s.numDocs - 1) // the segment's last doc
|
||||||
|
|
||||||
|
storedOffset0New := uint64(w.Count())
|
||||||
|
|
||||||
|
storedBytes := s.mem[storedOffset0 : storedOffsetN+readN+metaLenN+dataLenN]
|
||||||
|
_, err := w.Write(storedBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// remap the storedOffset's for the docs into new offsets relative
|
||||||
|
// to storedOffset0New, filling the given docNumOffsetsOut array
|
||||||
|
for indexOffset := indexOffset0; indexOffset <= indexOffsetN; indexOffset += 8 {
|
||||||
|
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
|
||||||
|
storedOffsetNew := storedOffset - storedOffset0 + storedOffset0New
|
||||||
|
newDocNumOffsets[newDocNum] = storedOffsetNew
|
||||||
|
newDocNum += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeFields builds a unified list of fields used across all the
|
||||||
|
// input segments, and computes whether the fields are the same across
|
||||||
|
// segments (which depends on fields to be sorted in the same way
|
||||||
|
// across segments)
|
||||||
|
func mergeFields(segments []*SegmentBase) (bool, []string) {
|
||||||
|
fieldsSame := true
|
||||||
|
|
||||||
|
var segment0Fields []string
|
||||||
|
if len(segments) > 0 {
|
||||||
|
segment0Fields = segments[0].Fields()
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldsExist := map[string]struct{}{}
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
fields := segment.Fields()
|
fields := segment.Fields()
|
||||||
for _, field := range fields {
|
for fieldi, field := range fields {
|
||||||
fieldsMap[field] = struct{}{}
|
fieldsExist[field] = struct{}{}
|
||||||
|
if len(segment0Fields) != len(fields) || segment0Fields[fieldi] != field {
|
||||||
|
fieldsSame = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rv := make([]string, 0, len(fieldsMap))
|
rv := make([]string, 0, len(fieldsExist))
|
||||||
// ensure _id stays first
|
// ensure _id stays first
|
||||||
rv = append(rv, "_id")
|
rv = append(rv, "_id")
|
||||||
for k := range fieldsMap {
|
for k := range fieldsExist {
|
||||||
if k != "_id" {
|
if k != "_id" {
|
||||||
rv = append(rv, k)
|
rv = append(rv, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rv
|
|
||||||
|
sort.Strings(rv[1:]) // leave _id as first
|
||||||
|
|
||||||
|
return fieldsSame, rv
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,21 +28,27 @@ import (
|
||||||
// PostingsList is an in-memory represenation of a postings list
|
// PostingsList is an in-memory represenation of a postings list
|
||||||
type PostingsList struct {
|
type PostingsList struct {
|
||||||
sb *SegmentBase
|
sb *SegmentBase
|
||||||
term []byte
|
|
||||||
postingsOffset uint64
|
postingsOffset uint64
|
||||||
freqOffset uint64
|
freqOffset uint64
|
||||||
locOffset uint64
|
locOffset uint64
|
||||||
locBitmap *roaring.Bitmap
|
locBitmap *roaring.Bitmap
|
||||||
postings *roaring.Bitmap
|
postings *roaring.Bitmap
|
||||||
except *roaring.Bitmap
|
except *roaring.Bitmap
|
||||||
postingKey []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterator returns an iterator for this postings list
|
// Iterator returns an iterator for this postings list
|
||||||
func (p *PostingsList) Iterator() segment.PostingsIterator {
|
func (p *PostingsList) Iterator() segment.PostingsIterator {
|
||||||
rv := &PostingsIterator{
|
return p.iterator(nil)
|
||||||
postings: p,
|
}
|
||||||
|
|
||||||
|
func (p *PostingsList) iterator(rv *PostingsIterator) *PostingsIterator {
|
||||||
|
if rv == nil {
|
||||||
|
rv = &PostingsIterator{}
|
||||||
|
} else {
|
||||||
|
*rv = PostingsIterator{} // clear the struct
|
||||||
}
|
}
|
||||||
|
rv.postings = p
|
||||||
|
|
||||||
if p.postings != nil {
|
if p.postings != nil {
|
||||||
// prepare the freq chunk details
|
// prepare the freq chunk details
|
||||||
var n uint64
|
var n uint64
|
||||||
|
|
|
@ -17,15 +17,27 @@ package zap
|
||||||
import "encoding/binary"
|
import "encoding/binary"
|
||||||
|
|
||||||
func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) {
|
func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) {
|
||||||
docStoredStartAddr := s.storedIndexOffset + (8 * docNum)
|
_, storedOffset, n, metaLen, dataLen := s.getDocStoredOffsets(docNum)
|
||||||
docStoredStart := binary.BigEndian.Uint64(s.mem[docStoredStartAddr : docStoredStartAddr+8])
|
|
||||||
var n uint64
|
meta := s.mem[storedOffset+n : storedOffset+n+metaLen]
|
||||||
metaLen, read := binary.Uvarint(s.mem[docStoredStart : docStoredStart+binary.MaxVarintLen64])
|
data := s.mem[storedOffset+n+metaLen : storedOffset+n+metaLen+dataLen]
|
||||||
n += uint64(read)
|
|
||||||
var dataLen uint64
|
|
||||||
dataLen, read = binary.Uvarint(s.mem[docStoredStart+n : docStoredStart+n+binary.MaxVarintLen64])
|
|
||||||
n += uint64(read)
|
|
||||||
meta := s.mem[docStoredStart+n : docStoredStart+n+metaLen]
|
|
||||||
data := s.mem[docStoredStart+n+metaLen : docStoredStart+n+metaLen+dataLen]
|
|
||||||
return meta, data
|
return meta, data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SegmentBase) getDocStoredOffsets(docNum uint64) (
|
||||||
|
uint64, uint64, uint64, uint64, uint64) {
|
||||||
|
indexOffset := s.storedIndexOffset + (8 * docNum)
|
||||||
|
|
||||||
|
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
|
||||||
|
|
||||||
|
var n uint64
|
||||||
|
|
||||||
|
metaLen, read := binary.Uvarint(s.mem[storedOffset : storedOffset+binary.MaxVarintLen64])
|
||||||
|
n += uint64(read)
|
||||||
|
|
||||||
|
dataLen, read := binary.Uvarint(s.mem[storedOffset+n : storedOffset+n+binary.MaxVarintLen64])
|
||||||
|
n += uint64(read)
|
||||||
|
|
||||||
|
return indexOffset, storedOffset, n, metaLen, dataLen
|
||||||
|
}
|
||||||
|
|
|
@ -343,8 +343,9 @@ func (s *SegmentBase) DocNumbers(ids []string) (*roaring.Bitmap, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var postings *PostingsList
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
postings, err := idDict.postingsList([]byte(id), nil)
|
postings, err = idDict.postingsList([]byte(id), nil, postings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,10 +31,9 @@ func (r *RollbackPoint) GetInternal(key []byte) []byte {
|
||||||
return r.meta[string(key)]
|
return r.meta[string(key)]
|
||||||
}
|
}
|
||||||
|
|
||||||
// RollbackPoints returns an array of rollback points available
|
// RollbackPoints returns an array of rollback points available for
|
||||||
// for the application to make a decision on where to rollback
|
// the application to rollback to, with more recent rollback points
|
||||||
// to. A nil return value indicates that there are no available
|
// (higher epochs) coming first.
|
||||||
// rollback points.
|
|
||||||
func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
|
func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
|
||||||
if s.rootBolt == nil {
|
if s.rootBolt == nil {
|
||||||
return nil, fmt.Errorf("RollbackPoints: root is nil")
|
return nil, fmt.Errorf("RollbackPoints: root is nil")
|
||||||
|
@ -54,7 +53,7 @@ func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
|
||||||
|
|
||||||
snapshots := tx.Bucket(boltSnapshotsBucket)
|
snapshots := tx.Bucket(boltSnapshotsBucket)
|
||||||
if snapshots == nil {
|
if snapshots == nil {
|
||||||
return nil, fmt.Errorf("RollbackPoints: no snapshots available")
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rollbackPoints := []*RollbackPoint{}
|
rollbackPoints := []*RollbackPoint{}
|
||||||
|
@ -150,10 +149,7 @@ func (s *Scorch) Rollback(to *RollbackPoint) error {
|
||||||
|
|
||||||
revert.snapshot = indexSnapshot
|
revert.snapshot = indexSnapshot
|
||||||
revert.applied = make(chan error)
|
revert.applied = make(chan error)
|
||||||
|
revert.persisted = make(chan error)
|
||||||
if !s.unsafeBatch {
|
|
||||||
revert.persisted = make(chan error)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -173,9 +169,5 @@ func (s *Scorch) Rollback(to *RollbackPoint) error {
|
||||||
return fmt.Errorf("Rollback: failed with err: %v", err)
|
return fmt.Errorf("Rollback: failed with err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if revert.persisted != nil {
|
return <-revert.persisted
|
||||||
err = <-revert.persisted
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -837,6 +837,11 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) {
|
||||||
docBackIndexRowErr = err
|
docBackIndexRowErr = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := kvreader.Close(); err == nil && cerr != nil {
|
||||||
|
docBackIndexRowErr = cerr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for docID, doc := range batch.IndexOps {
|
for docID, doc := range batch.IndexOps {
|
||||||
backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID))
|
backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID))
|
||||||
|
@ -847,12 +852,6 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) {
|
||||||
|
|
||||||
docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow}
|
docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = kvreader.Close()
|
|
||||||
if err != nil {
|
|
||||||
docBackIndexRowErr = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait for analysis result
|
// wait for analysis result
|
||||||
|
|
|
@ -15,12 +15,11 @@
|
||||||
package bleve
|
package bleve
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
|
|
||||||
"github.com/blevesearch/bleve/document"
|
"github.com/blevesearch/bleve/document"
|
||||||
"github.com/blevesearch/bleve/index"
|
"github.com/blevesearch/bleve/index"
|
||||||
"github.com/blevesearch/bleve/index/store"
|
"github.com/blevesearch/bleve/index/store"
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package bleve
|
package bleve
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -22,8 +23,6 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
|
|
||||||
"github.com/blevesearch/bleve/document"
|
"github.com/blevesearch/bleve/document"
|
||||||
"github.com/blevesearch/bleve/index"
|
"github.com/blevesearch/bleve/index"
|
||||||
"github.com/blevesearch/bleve/index/store"
|
"github.com/blevesearch/bleve/index/store"
|
||||||
|
|
|
@ -15,11 +15,10 @@
|
||||||
package search
|
package search
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/blevesearch/bleve/index"
|
"github.com/blevesearch/bleve/index"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Collector interface {
|
type Collector interface {
|
||||||
|
|
|
@ -15,11 +15,11 @@
|
||||||
package collector
|
package collector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/blevesearch/bleve/index"
|
"github.com/blevesearch/bleve/index"
|
||||||
"github.com/blevesearch/bleve/search"
|
"github.com/blevesearch/bleve/search"
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type collectorStore interface {
|
type collectorStore interface {
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
package goth
|
package goth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,17 +4,18 @@ package facebook
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"github.com/markbates/goth"
|
"github.com/markbates/goth"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
@ -22,7 +23,7 @@ import (
|
||||||
const (
|
const (
|
||||||
authURL string = "https://www.facebook.com/dialog/oauth"
|
authURL string = "https://www.facebook.com/dialog/oauth"
|
||||||
tokenURL string = "https://graph.facebook.com/oauth/access_token"
|
tokenURL string = "https://graph.facebook.com/oauth/access_token"
|
||||||
endpointProfile string = "https://graph.facebook.com/me?fields=email,first_name,last_name,link,about,id,name,picture,location"
|
endpointProfile string = "https://graph.facebook.com/me?fields="
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a new Facebook provider, and sets up important connection details.
|
// New creates a new Facebook provider, and sets up important connection details.
|
||||||
|
@ -68,9 +69,9 @@ func (p *Provider) Debug(debug bool) {}
|
||||||
|
|
||||||
// BeginAuth asks Facebook for an authentication end-point.
|
// BeginAuth asks Facebook for an authentication end-point.
|
||||||
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
|
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
|
||||||
url := p.config.AuthCodeURL(state)
|
authUrl := p.config.AuthCodeURL(state)
|
||||||
session := &Session{
|
session := &Session{
|
||||||
AuthURL: url,
|
AuthURL: authUrl,
|
||||||
}
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
@ -96,7 +97,15 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
|
||||||
hash.Write([]byte(sess.AccessToken))
|
hash.Write([]byte(sess.AccessToken))
|
||||||
appsecretProof := hex.EncodeToString(hash.Sum(nil))
|
appsecretProof := hex.EncodeToString(hash.Sum(nil))
|
||||||
|
|
||||||
response, err := p.Client().Get(endpointProfile + "&access_token=" + url.QueryEscape(sess.AccessToken) + "&appsecret_proof=" + appsecretProof)
|
reqUrl := fmt.Sprint(
|
||||||
|
endpointProfile,
|
||||||
|
strings.Join(p.config.Scopes, ","),
|
||||||
|
"&access_token=",
|
||||||
|
url.QueryEscape(sess.AccessToken),
|
||||||
|
"&appsecret_proof=",
|
||||||
|
appsecretProof,
|
||||||
|
)
|
||||||
|
response, err := p.Client().Get(reqUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
@ -168,17 +177,31 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config {
|
||||||
},
|
},
|
||||||
Scopes: []string{
|
Scopes: []string{
|
||||||
"email",
|
"email",
|
||||||
|
"first_name",
|
||||||
|
"last_name",
|
||||||
|
"link",
|
||||||
|
"about",
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"picture",
|
||||||
|
"location",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultScopes := map[string]struct{}{
|
// creates possibility to invoke field method like 'picture.type(large)'
|
||||||
"email": {},
|
var found bool
|
||||||
}
|
for _, sc := range scopes {
|
||||||
|
sc := sc
|
||||||
for _, scope := range scopes {
|
for i, defScope := range c.Scopes {
|
||||||
if _, exists := defaultScopes[scope]; !exists {
|
if defScope == strings.Split(sc, ".")[0] {
|
||||||
c.Scopes = append(c.Scopes, scope)
|
c.Scopes[i] = sc
|
||||||
|
found = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if !found {
|
||||||
|
c.Scopes = append(c.Scopes, sc)
|
||||||
|
}
|
||||||
|
found = false
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
|
||||||
|
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Do sends an HTTP request with the provided http.Client and returns
|
||||||
|
// an HTTP response.
|
||||||
|
//
|
||||||
|
// If the client is nil, http.DefaultClient is used.
|
||||||
|
//
|
||||||
|
// The provided ctx must be non-nil. If it is canceled or times out,
|
||||||
|
// ctx.Err() will be returned.
|
||||||
|
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req.WithContext(ctx))
|
||||||
|
// If we got an error, and the context has been canceled,
|
||||||
|
// the context's error is probably more useful.
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get issues a GET request via the Do function.
|
||||||
|
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return Do(ctx, client, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Head issues a HEAD request via the Do function.
|
||||||
|
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("HEAD", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return Do(ctx, client, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post issues a POST request via the Do function.
|
||||||
|
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("POST", url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", bodyType)
|
||||||
|
return Do(ctx, client, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostForm issues a POST request via the Do function.
|
||||||
|
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
|
||||||
|
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build !go1.7
|
||||||
|
|
||||||
|
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func nop() {}
|
||||||
|
|
||||||
|
var (
|
||||||
|
testHookContextDoneBeforeHeaders = nop
|
||||||
|
testHookDoReturned = nop
|
||||||
|
testHookDidBodyClose = nop
|
||||||
|
)
|
||||||
|
|
||||||
|
// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
|
||||||
|
// If the client is nil, http.DefaultClient is used.
|
||||||
|
// If the context is canceled or times out, ctx.Err() will be returned.
|
||||||
|
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(djd): Respect any existing value of req.Cancel.
|
||||||
|
cancel := make(chan struct{})
|
||||||
|
req.Cancel = cancel
|
||||||
|
|
||||||
|
type responseAndError struct {
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
result := make(chan responseAndError, 1)
|
||||||
|
|
||||||
|
// Make local copies of test hooks closed over by goroutines below.
|
||||||
|
// Prevents data races in tests.
|
||||||
|
testHookDoReturned := testHookDoReturned
|
||||||
|
testHookDidBodyClose := testHookDidBodyClose
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
testHookDoReturned()
|
||||||
|
result <- responseAndError{resp, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
testHookContextDoneBeforeHeaders()
|
||||||
|
close(cancel)
|
||||||
|
// Clean up after the goroutine calling client.Do:
|
||||||
|
go func() {
|
||||||
|
if r := <-result; r.resp != nil {
|
||||||
|
testHookDidBodyClose()
|
||||||
|
r.resp.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case r := <-result:
|
||||||
|
var err error
|
||||||
|
resp, err = r.resp, r.err
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(cancel)
|
||||||
|
case <-c:
|
||||||
|
// The response's Body is closed.
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
resp.Body = ¬ifyingReader{resp.Body, c}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get issues a GET request via the Do function.
|
||||||
|
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return Do(ctx, client, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Head issues a HEAD request via the Do function.
|
||||||
|
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("HEAD", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return Do(ctx, client, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post issues a POST request via the Do function.
|
||||||
|
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("POST", url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", bodyType)
|
||||||
|
return Do(ctx, client, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostForm issues a POST request via the Do function.
|
||||||
|
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
|
||||||
|
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// notifyingReader is an io.ReadCloser that closes the notify channel after
|
||||||
|
// Close is called or a Read fails on the underlying ReadCloser.
|
||||||
|
type notifyingReader struct {
|
||||||
|
io.ReadCloser
|
||||||
|
notify chan<- struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notifyingReader) Read(p []byte) (int, error) {
|
||||||
|
n, err := r.ReadCloser.Read(p)
|
||||||
|
if err != nil && r.notify != nil {
|
||||||
|
close(r.notify)
|
||||||
|
r.notify = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notifyingReader) Close() error {
|
||||||
|
err := r.ReadCloser.Close()
|
||||||
|
if r.notify != nil {
|
||||||
|
close(r.notify)
|
||||||
|
r.notify = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
Copyright (c) 2009 The oauth2 Authors. All rights reserved.
|
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
Redistribution and use in source and binary forms, with or without
|
||||||
modification, are permitted provided that the following conditions are
|
modification, are permitted provided that the following conditions are
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build appengine
|
|
||||||
|
|
||||||
// App Engine hooks.
|
|
||||||
|
|
||||||
package oauth2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
"golang.org/x/oauth2/internal"
|
|
||||||
"google.golang.org/appengine/urlfetch"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
internal.RegisterContextClientFunc(contextClientAppEngine)
|
|
||||||
}
|
|
||||||
|
|
||||||
func contextClientAppEngine(ctx context.Context) (*http.Client, error) {
|
|
||||||
return urlfetch.Client(ctx), nil
|
|
||||||
}
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
// Copyright 2018 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build appengine
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import "google.golang.org/appengine/urlfetch"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
appengineClientHook = urlfetch.Client
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
// Copyright 2017 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package internal contains support packages for oauth2 package.
|
||||||
|
package internal
|
|
@ -2,18 +2,14 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
// Package internal contains support packages for oauth2 package.
|
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseKey converts the binary contents of a private key file
|
// ParseKey converts the binary contents of a private key file
|
||||||
|
@ -30,7 +26,7 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
|
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err)
|
return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
parsed, ok := parsedKey.(*rsa.PrivateKey)
|
parsed, ok := parsedKey.(*rsa.PrivateKey)
|
||||||
|
@ -39,38 +35,3 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
|
||||||
}
|
}
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseINI(ini io.Reader) (map[string]map[string]string, error) {
|
|
||||||
result := map[string]map[string]string{
|
|
||||||
"": map[string]string{}, // root section
|
|
||||||
}
|
|
||||||
scanner := bufio.NewScanner(ini)
|
|
||||||
currentSection := ""
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if strings.HasPrefix(line, ";") {
|
|
||||||
// comment.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
|
|
||||||
currentSection = strings.TrimSpace(line[1 : len(line)-1])
|
|
||||||
result[currentSection] = map[string]string{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(line, "=", 2)
|
|
||||||
if len(parts) == 2 && parts[0] != "" {
|
|
||||||
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("error scanning ini: %v", err)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CondVal(v string) []string {
|
|
||||||
if v == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return []string{v}
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,11 +2,12 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
// Package internal contains support packages for oauth2 package.
|
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -17,10 +18,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context/ctxhttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Token represents the crendentials used to authorize
|
// Token represents the credentials used to authorize
|
||||||
// the requests to access protected resources on the OAuth 2.0
|
// the requests to access protected resources on the OAuth 2.0
|
||||||
// provider's backend.
|
// provider's backend.
|
||||||
//
|
//
|
||||||
|
@ -91,6 +92,7 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
var brokenAuthHeaderProviders = []string{
|
var brokenAuthHeaderProviders = []string{
|
||||||
"https://accounts.google.com/",
|
"https://accounts.google.com/",
|
||||||
|
"https://api.codeswholesale.com/oauth/token",
|
||||||
"https://api.dropbox.com/",
|
"https://api.dropbox.com/",
|
||||||
"https://api.dropboxapi.com/",
|
"https://api.dropboxapi.com/",
|
||||||
"https://api.instagram.com/",
|
"https://api.instagram.com/",
|
||||||
|
@ -99,10 +101,16 @@ var brokenAuthHeaderProviders = []string{
|
||||||
"https://api.pushbullet.com/",
|
"https://api.pushbullet.com/",
|
||||||
"https://api.soundcloud.com/",
|
"https://api.soundcloud.com/",
|
||||||
"https://api.twitch.tv/",
|
"https://api.twitch.tv/",
|
||||||
|
"https://id.twitch.tv/",
|
||||||
"https://app.box.com/",
|
"https://app.box.com/",
|
||||||
|
"https://api.box.com/",
|
||||||
"https://connect.stripe.com/",
|
"https://connect.stripe.com/",
|
||||||
|
"https://login.mailchimp.com/",
|
||||||
"https://login.microsoftonline.com/",
|
"https://login.microsoftonline.com/",
|
||||||
"https://login.salesforce.com/",
|
"https://login.salesforce.com/",
|
||||||
|
"https://login.windows.net",
|
||||||
|
"https://login.live.com/",
|
||||||
|
"https://login.live-int.com/",
|
||||||
"https://oauth.sandbox.trainingpeaks.com/",
|
"https://oauth.sandbox.trainingpeaks.com/",
|
||||||
"https://oauth.trainingpeaks.com/",
|
"https://oauth.trainingpeaks.com/",
|
||||||
"https://oauth.vk.com/",
|
"https://oauth.vk.com/",
|
||||||
|
@ -117,6 +125,24 @@ var brokenAuthHeaderProviders = []string{
|
||||||
"https://www.strava.com/oauth/",
|
"https://www.strava.com/oauth/",
|
||||||
"https://www.wunderlist.com/oauth/",
|
"https://www.wunderlist.com/oauth/",
|
||||||
"https://api.patreon.com/",
|
"https://api.patreon.com/",
|
||||||
|
"https://sandbox.codeswholesale.com/oauth/token",
|
||||||
|
"https://api.sipgate.com/v1/authorization/oauth",
|
||||||
|
"https://api.medium.com/v1/tokens",
|
||||||
|
"https://log.finalsurge.com/oauth/token",
|
||||||
|
"https://multisport.todaysplan.com.au/rest/oauth/access_token",
|
||||||
|
"https://whats.todaysplan.com.au/rest/oauth/access_token",
|
||||||
|
"https://stackoverflow.com/oauth/access_token",
|
||||||
|
"https://account.health.nokia.com",
|
||||||
|
"https://accounts.zoho.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
|
||||||
|
var brokenAuthHeaderDomains = []string{
|
||||||
|
".auth0.com",
|
||||||
|
".force.com",
|
||||||
|
".myshopify.com",
|
||||||
|
".okta.com",
|
||||||
|
".oktapreview.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
||||||
|
@ -139,6 +165,14 @@ func providerAuthHeaderWorks(tokenURL string) bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if u, err := url.Parse(tokenURL); err == nil {
|
||||||
|
for _, s := range brokenAuthHeaderDomains {
|
||||||
|
if strings.HasSuffix(u.Host, s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Assume the provider implements the spec properly
|
// Assume the provider implements the spec properly
|
||||||
// otherwise. We can add more exceptions as they're
|
// otherwise. We can add more exceptions as they're
|
||||||
// discovered. We will _not_ be adding configurable hooks
|
// discovered. We will _not_ be adding configurable hooks
|
||||||
|
@ -147,14 +181,14 @@ func providerAuthHeaderWorks(tokenURL string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
|
||||||
hc, err := ContextClient(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
v.Set("client_id", clientID)
|
|
||||||
bustedAuth := !providerAuthHeaderWorks(tokenURL)
|
bustedAuth := !providerAuthHeaderWorks(tokenURL)
|
||||||
if bustedAuth && clientSecret != "" {
|
if bustedAuth {
|
||||||
v.Set("client_secret", clientSecret)
|
if clientID != "" {
|
||||||
|
v.Set("client_id", clientID)
|
||||||
|
}
|
||||||
|
if clientSecret != "" {
|
||||||
|
v.Set("client_secret", clientSecret)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
|
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -162,9 +196,9 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
if !bustedAuth {
|
if !bustedAuth {
|
||||||
req.SetBasicAuth(clientID, clientSecret)
|
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
|
||||||
}
|
}
|
||||||
r, err := hc.Do(req)
|
r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -174,7 +208,10 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
if code := r.StatusCode; code < 200 || code > 299 {
|
if code := r.StatusCode; code < 200 || code > 299 {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
return nil, &RetrieveError{
|
||||||
|
Response: r,
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var token *Token
|
var token *Token
|
||||||
|
@ -221,5 +258,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||||
if token.RefreshToken == "" {
|
if token.RefreshToken == "" {
|
||||||
token.RefreshToken = v.Get("refresh_token")
|
token.RefreshToken = v.Get("refresh_token")
|
||||||
}
|
}
|
||||||
|
if token.AccessToken == "" {
|
||||||
|
return token, errors.New("oauth2: server response missing access_token")
|
||||||
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RetrieveError struct {
|
||||||
|
Response *http.Response
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RetrieveError) Error() string {
|
||||||
|
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
|
||||||
|
}
|
||||||
|
|
|
@ -2,13 +2,11 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
// Package internal contains support packages for oauth2 package.
|
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPClient is the context key to use with golang.org/x/net/context's
|
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||||
|
@ -20,50 +18,16 @@ var HTTPClient ContextKey
|
||||||
// because nobody else can create a ContextKey, being unexported.
|
// because nobody else can create a ContextKey, being unexported.
|
||||||
type ContextKey struct{}
|
type ContextKey struct{}
|
||||||
|
|
||||||
// ContextClientFunc is a func which tries to return an *http.Client
|
var appengineClientHook func(context.Context) *http.Client
|
||||||
// given a Context value. If it returns an error, the search stops
|
|
||||||
// with that error. If it returns (nil, nil), the search continues
|
|
||||||
// down the list of registered funcs.
|
|
||||||
type ContextClientFunc func(context.Context) (*http.Client, error)
|
|
||||||
|
|
||||||
var contextClientFuncs []ContextClientFunc
|
func ContextClient(ctx context.Context) *http.Client {
|
||||||
|
|
||||||
func RegisterContextClientFunc(fn ContextClientFunc) {
|
|
||||||
contextClientFuncs = append(contextClientFuncs, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ContextClient(ctx context.Context) (*http.Client, error) {
|
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
|
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
|
||||||
return hc, nil
|
return hc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, fn := range contextClientFuncs {
|
if appengineClientHook != nil {
|
||||||
c, err := fn(ctx)
|
return appengineClientHook(ctx)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if c != nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return http.DefaultClient, nil
|
return http.DefaultClient
|
||||||
}
|
|
||||||
|
|
||||||
func ContextTransport(ctx context.Context) http.RoundTripper {
|
|
||||||
hc, err := ContextClient(ctx)
|
|
||||||
// This is a rare error case (somebody using nil on App Engine).
|
|
||||||
if err != nil {
|
|
||||||
return ErrorTransport{err}
|
|
||||||
}
|
|
||||||
return hc.Transport
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorTransport returns the specified error on RoundTrip.
|
|
||||||
// This RoundTripper should be used in rare error cases where
|
|
||||||
// error handling can be postponed to response handling time.
|
|
||||||
type ErrorTransport struct{ Err error }
|
|
||||||
|
|
||||||
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
|
||||||
return nil, t.Err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,19 +3,20 @@
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
// Package oauth2 provides support for making
|
// Package oauth2 provides support for making
|
||||||
// OAuth2 authorized and authenticated HTTP requests.
|
// OAuth2 authorized and authenticated HTTP requests,
|
||||||
|
// as specified in RFC 6749.
|
||||||
// It can additionally grant authorization with Bearer JWT.
|
// It can additionally grant authorization with Bearer JWT.
|
||||||
package oauth2 // import "golang.org/x/oauth2"
|
package oauth2 // import "golang.org/x/oauth2"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
"golang.org/x/oauth2/internal"
|
"golang.org/x/oauth2/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -117,21 +118,30 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
|
||||||
// that asks for permissions for the required scopes explicitly.
|
// that asks for permissions for the required scopes explicitly.
|
||||||
//
|
//
|
||||||
// State is a token to protect the user from CSRF attacks. You must
|
// State is a token to protect the user from CSRF attacks. You must
|
||||||
// always provide a non-zero string and validate that it matches the
|
// always provide a non-empty string and validate that it matches the
|
||||||
// the state query parameter on your redirect callback.
|
// the state query parameter on your redirect callback.
|
||||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||||
//
|
//
|
||||||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||||
// as ApprovalForce.
|
// as ApprovalForce.
|
||||||
|
// It can also be used to pass the PKCE challange.
|
||||||
|
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
|
||||||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
buf.WriteString(c.Endpoint.AuthURL)
|
buf.WriteString(c.Endpoint.AuthURL)
|
||||||
v := url.Values{
|
v := url.Values{
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"client_id": {c.ClientID},
|
"client_id": {c.ClientID},
|
||||||
"redirect_uri": internal.CondVal(c.RedirectURL),
|
}
|
||||||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")),
|
if c.RedirectURL != "" {
|
||||||
"state": internal.CondVal(state),
|
v.Set("redirect_uri", c.RedirectURL)
|
||||||
|
}
|
||||||
|
if len(c.Scopes) > 0 {
|
||||||
|
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||||
|
}
|
||||||
|
if state != "" {
|
||||||
|
// TODO(light): Docs say never to omit state; don't allow empty.
|
||||||
|
v.Set("state", state)
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt.setValue(v)
|
opt.setValue(v)
|
||||||
|
@ -157,12 +167,15 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||||
// The HTTP client to use is derived from the context.
|
// The HTTP client to use is derived from the context.
|
||||||
// If nil, http.DefaultClient is used.
|
// If nil, http.DefaultClient is used.
|
||||||
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
|
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
|
||||||
return retrieveToken(ctx, c, url.Values{
|
v := url.Values{
|
||||||
"grant_type": {"password"},
|
"grant_type": {"password"},
|
||||||
"username": {username},
|
"username": {username},
|
||||||
"password": {password},
|
"password": {password},
|
||||||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")),
|
}
|
||||||
})
|
if len(c.Scopes) > 0 {
|
||||||
|
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||||
|
}
|
||||||
|
return retrieveToken(ctx, c, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange converts an authorization code into a token.
|
// Exchange converts an authorization code into a token.
|
||||||
|
@ -175,13 +188,21 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
|
||||||
//
|
//
|
||||||
// The code will be in the *http.Request.FormValue("code"). Before
|
// The code will be in the *http.Request.FormValue("code"). Before
|
||||||
// calling Exchange, be sure to validate FormValue("state").
|
// calling Exchange, be sure to validate FormValue("state").
|
||||||
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
|
//
|
||||||
return retrieveToken(ctx, c, url.Values{
|
// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
|
||||||
"grant_type": {"authorization_code"},
|
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
|
||||||
"code": {code},
|
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
|
||||||
"redirect_uri": internal.CondVal(c.RedirectURL),
|
v := url.Values{
|
||||||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")),
|
"grant_type": {"authorization_code"},
|
||||||
})
|
"code": {code},
|
||||||
|
}
|
||||||
|
if c.RedirectURL != "" {
|
||||||
|
v.Set("redirect_uri", c.RedirectURL)
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.setValue(v)
|
||||||
|
}
|
||||||
|
return retrieveToken(ctx, c, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client returns an HTTP client using the provided token.
|
// Client returns an HTTP client using the provided token.
|
||||||
|
@ -292,20 +313,20 @@ var HTTPClient internal.ContextKey
|
||||||
// NewClient creates an *http.Client from a Context and TokenSource.
|
// NewClient creates an *http.Client from a Context and TokenSource.
|
||||||
// The returned client is not valid beyond the lifetime of the context.
|
// The returned client is not valid beyond the lifetime of the context.
|
||||||
//
|
//
|
||||||
|
// Note that if a custom *http.Client is provided via the Context it
|
||||||
|
// is used only for token acquisition and is not used to configure the
|
||||||
|
// *http.Client returned from NewClient.
|
||||||
|
//
|
||||||
// As a special case, if src is nil, a non-OAuth2 client is returned
|
// As a special case, if src is nil, a non-OAuth2 client is returned
|
||||||
// using the provided context. This exists to support related OAuth2
|
// using the provided context. This exists to support related OAuth2
|
||||||
// packages.
|
// packages.
|
||||||
func NewClient(ctx context.Context, src TokenSource) *http.Client {
|
func NewClient(ctx context.Context, src TokenSource) *http.Client {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
c, err := internal.ContextClient(ctx)
|
return internal.ContextClient(ctx)
|
||||||
if err != nil {
|
|
||||||
return &http.Client{Transport: internal.ErrorTransport{Err: err}}
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
}
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: &Transport{
|
Transport: &Transport{
|
||||||
Base: internal.ContextTransport(ctx),
|
Base: internal.ContextClient(ctx).Transport,
|
||||||
Source: ReuseTokenSource(nil, src),
|
Source: ReuseTokenSource(nil, src),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,13 +5,14 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
"golang.org/x/oauth2/internal"
|
"golang.org/x/oauth2/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ import (
|
||||||
// expirations due to client-server time mismatches.
|
// expirations due to client-server time mismatches.
|
||||||
const expiryDelta = 10 * time.Second
|
const expiryDelta = 10 * time.Second
|
||||||
|
|
||||||
// Token represents the crendentials used to authorize
|
// Token represents the credentials used to authorize
|
||||||
// the requests to access protected resources on the OAuth 2.0
|
// the requests to access protected resources on the OAuth 2.0
|
||||||
// provider's backend.
|
// provider's backend.
|
||||||
//
|
//
|
||||||
|
@ -123,7 +124,7 @@ func (t *Token) expired() bool {
|
||||||
if t.Expiry.IsZero() {
|
if t.Expiry.IsZero() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return t.Expiry.Add(-expiryDelta).Before(time.Now())
|
return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
|
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
|
||||||
|
@ -152,7 +153,23 @@ func tokenFromInternal(t *internal.Token) *Token {
|
||||||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
||||||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if rErr, ok := err.(*internal.RetrieveError); ok {
|
||||||
|
return nil, (*RetrieveError)(rErr)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return tokenFromInternal(tk), nil
|
return tokenFromInternal(tk), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RetrieveError is the error returned when the token endpoint returns a
|
||||||
|
// non-2XX HTTP status code.
|
||||||
|
type RetrieveError struct {
|
||||||
|
Response *http.Response
|
||||||
|
// Body is the body that was consumed by reading Response.Body.
|
||||||
|
// It may be truncated.
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RetrieveError) Error() string {
|
||||||
|
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
|
||||||
|
}
|
||||||
|
|
|
@ -31,9 +31,17 @@ type Transport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip authorizes and authenticates the request with an
|
// RoundTrip authorizes and authenticates the request with an
|
||||||
// access token. If no token exists or token is expired,
|
// access token from Transport's Source.
|
||||||
// tries to refresh/fetch a new token.
|
|
||||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
reqBodyClosed := false
|
||||||
|
if req.Body != nil {
|
||||||
|
defer func() {
|
||||||
|
if !reqBodyClosed {
|
||||||
|
req.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
if t.Source == nil {
|
if t.Source == nil {
|
||||||
return nil, errors.New("oauth2: Transport's Source is nil")
|
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||||
}
|
}
|
||||||
|
@ -46,6 +54,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
token.SetAuthHeader(req2)
|
token.SetAuthHeader(req2)
|
||||||
t.setModReq(req, req2)
|
t.setModReq(req, req2)
|
||||||
res, err := t.base().RoundTrip(req2)
|
res, err := t.base().RoundTrip(req2)
|
||||||
|
|
||||||
|
// req.Body is assumed to have been closed by the base RoundTripper.
|
||||||
|
reqBodyClosed = true
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.setModReq(req, nil)
|
t.setModReq(req, nil)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Reference in New Issue