Merge pull request #1491 from influxdata/feature/godep
Update our go vendoring to use deppull/1458/head^2
commit
f9602792af
|
@ -10,6 +10,7 @@
|
|||
### Features
|
||||
1. [#1477](https://github.com/influxdata/chronograf/pull/1477): Add ability to log alerts
|
||||
1. [#1474](https://github.com/influxdata/chronograf/pull/1474): Change behavior of template variable autocomplete to filter by exact match from front of string
|
||||
1. [#1491](https://github.com/influxdata/chronograf/pull/1491): Update go vendoring to dep and committed vendor directory
|
||||
|
||||
### UI Improvements
|
||||
1. [#1451](https://github.com/influxdata/chronograf/pull/1451): Refactor scrollbars to support non-webkit browsers
|
||||
|
|
19
Godeps
19
Godeps
|
@ -1,19 +0,0 @@
|
|||
github.com/NYTimes/gziphandler 6710af535839f57c687b62c4c23d649f9545d885
|
||||
github.com/Sirupsen/logrus 3ec0642a7fb6488f65b06f9040adc67e3990296a
|
||||
github.com/boltdb/bolt 5cc10bbbc5c141029940133bb33c9e969512a698
|
||||
github.com/bouk/httprouter ee8b3818a7f51fbc94cc709b5744b52c2c725e91
|
||||
github.com/dgrijalva/jwt-go 24c63f56522a87ec5339cc3567883f1039378fdb
|
||||
github.com/elazarl/go-bindata-assetfs 9a6736ed45b44bf3835afeebb3034b57ed329f3e
|
||||
github.com/gogo/protobuf 6abcf94fd4c97dcb423fdafd42fe9f96ca7e421b
|
||||
github.com/google/go-github 1bc362c7737e51014af7299e016444b654095ad9
|
||||
github.com/google/go-querystring 9235644dd9e52eeae6fa48efd539fdc351a0af53
|
||||
github.com/influxdata/influxdb af72d9b0e4ebe95be30e89b160f43eabaf0529ed
|
||||
github.com/influxdata/kapacitor 5408057e5a3493d3b5bd38d5d535ea45b587f8ff
|
||||
github.com/influxdata/usage-client 6d3895376368aa52a3a81d2a16e90f0f52371967
|
||||
github.com/jessevdk/go-flags 4cc2832a6e6d1d3b815e2b9d544b2a4dfb3ce8fa
|
||||
github.com/satori/go.uuid b061729afc07e77a8aa4fad0a2fd840958f1942a
|
||||
github.com/sergi/go-diff 1d28411638c1e67fe1930830df207bef72496ae9
|
||||
github.com/tylerb/graceful 50a48b6e73fcc75b45e22c05b79629a67c79e938
|
||||
golang.org/x/net 749a502dd1eaf3e5bfd4f8956748c502357c0bbe
|
||||
golang.org/x/oauth2 1e695b1c8febf17aad3bfa7bf0a819ef94b98ad5
|
||||
google.golang.org/api bc20c61134e1d25265dd60049f5735381e79b631
|
|
@ -0,0 +1,132 @@
|
|||
memo = "bac138180cd86a0ae604cd3aa7b6ba300673478c880882bd58a4bd7f8bff518d"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/NYTimes/gziphandler"
|
||||
packages = ["."]
|
||||
revision = "6710af535839f57c687b62c4c23d649f9545d885"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/Sirupsen/logrus"
|
||||
packages = ["."]
|
||||
revision = "3ec0642a7fb6488f65b06f9040adc67e3990296a"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/boltdb/bolt"
|
||||
packages = ["."]
|
||||
revision = "5cc10bbbc5c141029940133bb33c9e969512a698"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/bouk/httprouter"
|
||||
packages = ["."]
|
||||
revision = "ee8b3818a7f51fbc94cc709b5744b52c2c725e91"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/dgrijalva/jwt-go"
|
||||
packages = ["."]
|
||||
revision = "24c63f56522a87ec5339cc3567883f1039378fdb"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/dustin/go-humanize"
|
||||
packages = ["."]
|
||||
revision = "259d2a102b871d17f30e3cd9881a642961a1e486"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/elazarl/go-bindata-assetfs"
|
||||
packages = ["."]
|
||||
revision = "9a6736ed45b44bf3835afeebb3034b57ed329f3e"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/gogo/protobuf"
|
||||
packages = ["gogoproto","jsonpb","plugin/compare","plugin/defaultcheck","plugin/description","plugin/embedcheck","plugin/enumstringer","plugin/equal","plugin/face","plugin/gostring","plugin/marshalto","plugin/oneofcheck","plugin/populate","plugin/size","plugin/stringer","plugin/testgen","plugin/union","plugin/unmarshal","proto","protoc-gen-gogo","protoc-gen-gogo/descriptor","protoc-gen-gogo/generator","protoc-gen-gogo/grpc","protoc-gen-gogo/plugin","vanity","vanity/command"]
|
||||
revision = "6abcf94fd4c97dcb423fdafd42fe9f96ca7e421b"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/golang/protobuf"
|
||||
packages = ["proto"]
|
||||
revision = "8ee79997227bf9b34611aee7946ae64735e6fd93"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/google/go-github"
|
||||
packages = ["github"]
|
||||
revision = "1bc362c7737e51014af7299e016444b654095ad9"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/google/go-querystring"
|
||||
packages = ["query"]
|
||||
revision = "9235644dd9e52eeae6fa48efd539fdc351a0af53"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/influxdata/influxdb"
|
||||
packages = ["influxql","influxql/internal","influxql/neldermead","models","pkg/escape"]
|
||||
revision = "af72d9b0e4ebe95be30e89b160f43eabaf0529ed"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/influxdata/kapacitor"
|
||||
packages = ["client/v1","influxdb","models","pipeline","services/k8s/client","tick","tick/ast","tick/stateful","udf"]
|
||||
revision = "5408057e5a3493d3b5bd38d5d535ea45b587f8ff"
|
||||
version = "v1.2.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/influxdata/usage-client"
|
||||
packages = ["v1"]
|
||||
revision = "6d3895376368aa52a3a81d2a16e90f0f52371967"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/jessevdk/go-flags"
|
||||
packages = ["."]
|
||||
revision = "4cc2832a6e6d1d3b815e2b9d544b2a4dfb3ce8fa"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/jteeuwen/go-bindata"
|
||||
packages = ["."]
|
||||
revision = "a0ff2567cfb70903282db057e799fd826784d41d"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/pkg/errors"
|
||||
packages = ["."]
|
||||
revision = "ff09b135c25aae272398c51a07235b90a75aa4f0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/satori/go.uuid"
|
||||
packages = ["."]
|
||||
revision = "b061729afc07e77a8aa4fad0a2fd840958f1942a"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/sergi/go-diff"
|
||||
packages = ["diffmatchpatch"]
|
||||
revision = "1d28411638c1e67fe1930830df207bef72496ae9"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/tylerb/graceful"
|
||||
packages = ["."]
|
||||
revision = "50a48b6e73fcc75b45e22c05b79629a67c79e938"
|
||||
version = "v1.2.13"
|
||||
|
||||
[[projects]]
|
||||
name = "golang.org/x/net"
|
||||
packages = ["context","context/ctxhttp"]
|
||||
revision = "749a502dd1eaf3e5bfd4f8956748c502357c0bbe"
|
||||
|
||||
[[projects]]
|
||||
name = "golang.org/x/oauth2"
|
||||
packages = [".","github","heroku","internal"]
|
||||
revision = "1e695b1c8febf17aad3bfa7bf0a819ef94b98ad5"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/sys"
|
||||
packages = ["unix"]
|
||||
revision = "f3918c30c5c2cb527c0b071a27c35120a6c0719a"
|
||||
|
||||
[[projects]]
|
||||
name = "google.golang.org/api"
|
||||
packages = ["gensupport","googleapi","googleapi/internal/uritemplates","oauth2/v2"]
|
||||
revision = "bc20c61134e1d25265dd60049f5735381e79b631"
|
||||
|
||||
[[projects]]
|
||||
name = "google.golang.org/appengine"
|
||||
packages = ["internal","internal/base","internal/datastore","internal/log","internal/remote_api","internal/urlfetch","urlfetch"]
|
||||
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
|
||||
version = "v1.0.0"
|
|
@ -0,0 +1,77 @@
|
|||
required = ["github.com/jteeuwen/go-bindata","github.com/gogo/protobuf/proto","github.com/gogo/protobuf/jsonpb","github.com/gogo/protobuf/protoc-gen-gogo","github.com/gogo/protobuf/gogoproto"]
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/NYTimes/gziphandler"
|
||||
revision = "6710af535839f57c687b62c4c23d649f9545d885"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/Sirupsen/logrus"
|
||||
revision = "3ec0642a7fb6488f65b06f9040adc67e3990296a"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/boltdb/bolt"
|
||||
revision = "5cc10bbbc5c141029940133bb33c9e969512a698"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/bouk/httprouter"
|
||||
revision = "ee8b3818a7f51fbc94cc709b5744b52c2c725e91"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/dgrijalva/jwt-go"
|
||||
revision = "24c63f56522a87ec5339cc3567883f1039378fdb"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/elazarl/go-bindata-assetfs"
|
||||
revision = "9a6736ed45b44bf3835afeebb3034b57ed329f3e"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/gogo/protobuf"
|
||||
revision = "6abcf94fd4c97dcb423fdafd42fe9f96ca7e421b"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/google/go-github"
|
||||
revision = "1bc362c7737e51014af7299e016444b654095ad9"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/influxdata/influxdb"
|
||||
revision = "af72d9b0e4ebe95be30e89b160f43eabaf0529ed"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/influxdata/kapacitor"
|
||||
version = "^1.2.0"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/influxdata/usage-client"
|
||||
revision = "6d3895376368aa52a3a81d2a16e90f0f52371967"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/jessevdk/go-flags"
|
||||
revision = "4cc2832a6e6d1d3b815e2b9d544b2a4dfb3ce8fa"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/jteeuwen/go-bindata"
|
||||
revision = "a0ff2567cfb70903282db057e799fd826784d41d"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/satori/go.uuid"
|
||||
revision = "b061729afc07e77a8aa4fad0a2fd840958f1942a"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/sergi/go-diff"
|
||||
revision = "1d28411638c1e67fe1930830df207bef72496ae9"
|
||||
|
||||
[[dependencies]]
|
||||
name = "github.com/tylerb/graceful"
|
||||
version = "^1.2.13"
|
||||
|
||||
[[dependencies]]
|
||||
name = "golang.org/x/net"
|
||||
revision = "749a502dd1eaf3e5bfd4f8956748c502357c0bbe"
|
||||
|
||||
[[dependencies]]
|
||||
name = "golang.org/x/oauth2"
|
||||
revision = "1e695b1c8febf17aad3bfa7bf0a819ef94b98ad5"
|
||||
|
||||
[[dependencies]]
|
||||
name = "google.golang.org/api"
|
||||
revision = "bc20c61134e1d25265dd60049f5735381e79b631"
|
21
Makefile
21
Makefile
|
@ -2,11 +2,10 @@
|
|||
|
||||
VERSION ?= $(shell git describe --always --tags)
|
||||
COMMIT ?= $(shell git rev-parse --short=8 HEAD)
|
||||
GDM := $(shell command -v gdm 2> /dev/null)
|
||||
GOBINDATA := $(shell go list -f {{.Root}} github.com/jteeuwen/go-bindata 2> /dev/null)
|
||||
YARN := $(shell command -v yarn 2> /dev/null)
|
||||
|
||||
SOURCES := $(shell find . -name '*.go' ! -name '*_gen.go')
|
||||
SOURCES := $(shell find . -name '*.go' ! -name '*_gen.go' -not -path "./vendor/*" )
|
||||
UISOURCES := $(shell find ui -type f -not \( -path ui/build/\* -o -path ui/node_modules/\* -prune \) )
|
||||
|
||||
LDFLAGS=-ldflags "-s -X main.version=${VERSION} -X main.commit=${COMMIT}"
|
||||
|
@ -70,16 +69,11 @@ canned/bin_gen.go: canned/*.json
|
|||
|
||||
dep: .jsdep .godep
|
||||
|
||||
.godep: Godeps
|
||||
ifndef GDM
|
||||
@echo "Installing GDM"
|
||||
go get github.com/sparrc/gdm
|
||||
endif
|
||||
.godep:
|
||||
ifndef GOBINDATA
|
||||
@echo "Installing go-bindata"
|
||||
go get -u github.com/jteeuwen/go-bindata/...
|
||||
endif
|
||||
gdm restore
|
||||
@touch .godep
|
||||
|
||||
.jsdep: ui/yarn.lock
|
||||
|
@ -90,16 +84,18 @@ else
|
|||
@touch .jsdep
|
||||
endif
|
||||
|
||||
gen: bolt/internal/internal.proto
|
||||
gen: internal.pb.go
|
||||
|
||||
internal.pb.go: bolt/internal/internal.proto
|
||||
go generate -x ./bolt/internal
|
||||
|
||||
test: jstest gotest gotestrace
|
||||
|
||||
gotest:
|
||||
go test ./...
|
||||
go test `go list ./... | grep -v /vendor/`
|
||||
|
||||
gotestrace:
|
||||
go test -race ./...
|
||||
go test -race `go list ./... | grep -v /vendor/`
|
||||
|
||||
jstest:
|
||||
cd ui && npm test
|
||||
|
@ -117,8 +113,5 @@ clean:
|
|||
rm -f dist/dist_gen.go canned/bin_gen.go server/swagger_gen.go
|
||||
@rm -f .godep .jsdep .jssrc .dev-jssrc .bindata
|
||||
|
||||
continuous:
|
||||
while true; do if fswatch -e "\.git" -r --one-event .; then echo "#-> Starting build: `date`"; make dev; pkill -9 chronograf; make run-dev & echo "#-> Build complete."; fi; sleep 0.5; done
|
||||
|
||||
ctags:
|
||||
ctags -R --languages="Go" --exclude=.git --exclude=ui .
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
*.swp
|
|
@ -0,0 +1,5 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.7
|
||||
- tip
|
|
@ -0,0 +1,75 @@
|
|||
---
|
||||
layout: code-of-conduct
|
||||
version: v1.0
|
||||
---
|
||||
|
||||
This code of conduct outlines our expectations for participants within the **NYTimes/gziphandler** community, as well as steps to reporting unacceptable behavior. We are committed to providing a welcoming and inspiring community for all and expect our code of conduct to be honored. Anyone who violates this code of conduct may be banned from the community.
|
||||
|
||||
Our open source community strives to:
|
||||
|
||||
* **Be friendly and patient.**
|
||||
* **Be welcoming**: We strive to be a community that welcomes and supports people of all backgrounds and identities. This includes, but is not limited to members of any race, ethnicity, culture, national origin, colour, immigration status, social and economic class, educational level, sex, sexual orientation, gender identity and expression, age, size, family status, political belief, religion, and mental and physical ability.
|
||||
* **Be considerate**: Your work will be used by other people, and you in turn will depend on the work of others. Any decision you take will affect users and colleagues, and you should take those consequences into account when making decisions. Remember that we're a world-wide community, so you might not be communicating in someone else's primary language.
|
||||
* **Be respectful**: Not all of us will agree all the time, but disagreement is no excuse for poor behavior and poor manners. We might all experience some frustration now and then, but we cannot allow that frustration to turn into a personal attack. It’s important to remember that a community where people feel uncomfortable or threatened is not a productive one.
|
||||
* **Be careful in the words that we choose**: we are a community of professionals, and we conduct ourselves professionally. Be kind to others. Do not insult or put down other participants. Harassment and other exclusionary behavior aren't acceptable.
|
||||
* **Try to understand why we disagree**: Disagreements, both social and technical, happen all the time. It is important that we resolve disagreements and differing views constructively. Remember that we’re different. The strength of our community comes from its diversity, people from a wide range of backgrounds. Different people have different perspectives on issues. Being unable to understand why someone holds a viewpoint doesn’t mean that they’re wrong. Don’t forget that it is human to err and blaming each other doesn’t get us anywhere. Instead, focus on helping to resolve issues and learning from mistakes.
|
||||
|
||||
## Definitions
|
||||
|
||||
Harassment includes, but is not limited to:
|
||||
|
||||
- Offensive comments related to gender, gender identity and expression, sexual orientation, disability, mental illness, neuro(a)typicality, physical appearance, body size, race, age, regional discrimination, political or religious affiliation
|
||||
- Unwelcome comments regarding a person’s lifestyle choices and practices, including those related to food, health, parenting, drugs, and employment
|
||||
- Deliberate misgendering. This includes deadnaming or persistently using a pronoun that does not correctly reflect a person's gender identity. You must address people by the name they give you when not addressing them by their username or handle
|
||||
- Physical contact and simulated physical contact (eg, textual descriptions like “*hug*” or “*backrub*”) without consent or after a request to stop
|
||||
- Threats of violence, both physical and psychological
|
||||
- Incitement of violence towards any individual, including encouraging a person to commit suicide or to engage in self-harm
|
||||
- Deliberate intimidation
|
||||
- Stalking or following
|
||||
- Harassing photography or recording, including logging online activity for harassment purposes
|
||||
- Sustained disruption of discussion
|
||||
- Unwelcome sexual attention, including gratuitous or off-topic sexual images or behaviour
|
||||
- Pattern of inappropriate social contact, such as requesting/assuming inappropriate levels of intimacy with others
|
||||
- Continued one-on-one communication after requests to cease
|
||||
- Deliberate “outing” of any aspect of a person’s identity without their consent except as necessary to protect others from intentional abuse
|
||||
- Publication of non-harassing private communication
|
||||
|
||||
Our open source community prioritizes marginalized people’s safety over privileged people’s comfort. We will not act on complaints regarding:
|
||||
|
||||
- ‘Reverse’ -isms, including ‘reverse racism,’ ‘reverse sexism,’ and ‘cisphobia’
|
||||
- Reasonable communication of boundaries, such as “leave me alone,” “go away,” or “I’m not discussing this with you”
|
||||
- Refusal to explain or debate social justice concepts
|
||||
- Communicating in a ‘tone’ you don’t find congenial
|
||||
- Criticizing racist, sexist, cissexist, or otherwise oppressive behavior or assumptions
|
||||
|
||||
|
||||
### Diversity Statement
|
||||
|
||||
We encourage everyone to participate and are committed to building a community for all. Although we will fail at times, we seek to treat everyone both as fairly and equally as possible. Whenever a participant has made a mistake, we expect them to take responsibility for it. If someone has been harmed or offended, it is our responsibility to listen carefully and respectfully, and do our best to right the wrong.
|
||||
|
||||
Although this list cannot be exhaustive, we explicitly honor diversity in age, gender, gender identity or expression, culture, ethnicity, language, national origin, political beliefs, profession, race, religion, sexual orientation, socioeconomic status, and technical ability. We will not tolerate discrimination based on any of the protected
|
||||
characteristics above, including participants with disabilities.
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
If you experience or witness unacceptable behavior—or have any other concerns—please report it by contacting us via **code@nytimes.com**. All reports will be handled with discretion. In your report please include:
|
||||
|
||||
- Your contact information.
|
||||
- Names (real, nicknames, or pseudonyms) of any individuals involved. If there are additional witnesses, please
|
||||
include them as well. Your account of what occurred, and if you believe the incident is ongoing. If there is a publicly available record (e.g. a mailing list archive or a public IRC logger), please include a link.
|
||||
- Any additional information that may be helpful.
|
||||
|
||||
After filing a report, a representative will contact you personally, review the incident, follow up with any additional questions, and make a decision as to how to respond. If the person who is harassing you is part of the response team, they will recuse themselves from handling your incident. If the complaint originates from a member of the response team, it will be handled by a different member of the response team. We will respect confidentiality requests for the purpose of protecting victims of abuse.
|
||||
|
||||
### Attribution & Acknowledgements
|
||||
|
||||
We all stand on the shoulders of giants across many open source communities. We'd like to thank the communities and projects that established code of conducts and diversity statements as our inspiration:
|
||||
|
||||
* [Django](https://www.djangoproject.com/conduct/reporting/)
|
||||
* [Python](https://www.python.org/community/diversity/)
|
||||
* [Ubuntu](http://www.ubuntu.com/about/about-ubuntu/conduct)
|
||||
* [Contributor Covenant](http://contributor-covenant.org/)
|
||||
* [Geek Feminism](http://geekfeminism.org/about/code-of-conduct/)
|
||||
* [Citizen Code of Conduct](http://citizencodeofconduct.org/)
|
||||
|
||||
This Code of Conduct was based on https://github.com/todogroup/opencodeofconduct
|
|
@ -0,0 +1,30 @@
|
|||
# Contributing to NYTimes/gziphandler
|
||||
|
||||
This is an open source project started by handful of developers at The New York Times and open to the entire Go community.
|
||||
|
||||
We really appreciate your help!
|
||||
|
||||
## Filing issues
|
||||
|
||||
When filing an issue, make sure to answer these five questions:
|
||||
|
||||
1. What version of Go are you using (`go version`)?
|
||||
2. What operating system and processor architecture are you using?
|
||||
3. What did you do?
|
||||
4. What did you expect to see?
|
||||
5. What did you see instead?
|
||||
|
||||
## Contributing code
|
||||
|
||||
Before submitting changes, please follow these guidelines:
|
||||
|
||||
1. Check the open issues and pull requests for existing discussions.
|
||||
2. Open an issue to discuss a new feature.
|
||||
3. Write tests.
|
||||
4. Make sure code follows the ['Go Code Review Comments'](https://github.com/golang/go/wiki/CodeReviewComments).
|
||||
5. Make sure your changes pass `go test`.
|
||||
6. Make sure the entire test suite passes locally and on Travis CI.
|
||||
7. Open a Pull Request.
|
||||
8. [Squash your commits](http://gitready.com/advanced/2009/02/10/squashing-commits-with-rebase.html) after receiving feedback and add a [great commit message](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html).
|
||||
|
||||
Unless otherwise noted, the gziphandler source files are distributed under the Apache 2.0-style license found in the LICENSE.md file.
|
|
@ -0,0 +1,13 @@
|
|||
Copyright (c) 2015 The New York Times Company
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this library except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,52 @@
|
|||
Gzip Handler
|
||||
============
|
||||
|
||||
This is a tiny Go package which wraps HTTP handlers to transparently gzip the
|
||||
response body, for clients which support it. Although it's usually simpler to
|
||||
leave that to a reverse proxy (like nginx or Varnish), this package is useful
|
||||
when that's undesirable.
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
Call `GzipHandler` with any handler (an object which implements the
|
||||
`http.Handler` interface), and it'll return a new handler which gzips the
|
||||
response. For example:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
|
||||
func main() {
|
||||
withoutGz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
io.WriteString(w, "Hello, World")
|
||||
})
|
||||
|
||||
withGz := gziphandler.GzipHandler(withoutGz)
|
||||
|
||||
http.Handle("/", withGz)
|
||||
http.ListenAndServe("0.0.0.0:8000", nil)
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Documentation
|
||||
|
||||
The docs can be found at [godoc.org] [docs], as usual.
|
||||
|
||||
|
||||
## License
|
||||
|
||||
[Apache 2.0] [license].
|
||||
|
||||
|
||||
|
||||
|
||||
[docs]: https://godoc.org/github.com/nytimes/gziphandler
|
||||
[license]: https://github.com/nytimes/gziphandler/blob/master/LICENSE.md
|
|
@ -0,0 +1,260 @@
|
|||
package gziphandler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
vary = "Vary"
|
||||
acceptEncoding = "Accept-Encoding"
|
||||
contentEncoding = "Content-Encoding"
|
||||
contentType = "Content-Type"
|
||||
contentLength = "Content-Length"
|
||||
)
|
||||
|
||||
type codings map[string]float64
|
||||
|
||||
// The default qvalue to assign to an encoding if no explicit qvalue is set.
|
||||
// This is actually kind of ambiguous in RFC 2616, so hopefully it's correct.
|
||||
// The examples seem to indicate that it is.
|
||||
const DEFAULT_QVALUE = 1.0
|
||||
|
||||
// gzipWriterPools stores a sync.Pool for each compression level for reuse of
|
||||
// gzip.Writers. Use poolIndex to covert a compression level to an index into
|
||||
// gzipWriterPools.
|
||||
var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool
|
||||
|
||||
func init() {
|
||||
for i := gzip.BestSpeed; i <= gzip.BestCompression; i++ {
|
||||
addLevelPool(i)
|
||||
}
|
||||
addLevelPool(gzip.DefaultCompression)
|
||||
}
|
||||
|
||||
// poolIndex maps a compression level to its index into gzipWriterPools. It
|
||||
// assumes that level is a valid gzip compression level.
|
||||
func poolIndex(level int) int {
|
||||
// gzip.DefaultCompression == -1, so we need to treat it special.
|
||||
if level == gzip.DefaultCompression {
|
||||
return gzip.BestCompression - gzip.BestSpeed + 1
|
||||
}
|
||||
return level - gzip.BestSpeed
|
||||
}
|
||||
|
||||
func addLevelPool(level int) {
|
||||
gzipWriterPools[poolIndex(level)] = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
// NewWriterLevel only returns error on a bad level, we are guaranteeing
|
||||
// that this will be a valid level so it is okay to ignore the returned
|
||||
// error.
|
||||
w, _ := gzip.NewWriterLevel(nil, level)
|
||||
return w
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GzipResponseWriter provides an http.ResponseWriter interface, which gzips
|
||||
// bytes before writing them to the underlying response. This doesn't close the
|
||||
// writers, so don't forget to do that.
|
||||
type GzipResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
index int // Index for gzipWriterPools.
|
||||
gw *gzip.Writer
|
||||
}
|
||||
|
||||
// Write appends data to the gzip writer.
|
||||
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
|
||||
// Lazily create the gzip.Writer, this allows empty bodies to be actually
|
||||
// empty, for example in the case of status code 204 (no content).
|
||||
if w.gw == nil {
|
||||
w.init()
|
||||
}
|
||||
|
||||
if _, ok := w.Header()[contentType]; !ok {
|
||||
// If content type is not set, infer it from the uncompressed body.
|
||||
w.Header().Set(contentType, http.DetectContentType(b))
|
||||
}
|
||||
return w.gw.Write(b)
|
||||
}
|
||||
|
||||
// WriteHeader will check if the gzip writer needs to be lazily initiated and
|
||||
// then pass the code along to the underlying ResponseWriter.
|
||||
func (w *GzipResponseWriter) WriteHeader(code int) {
|
||||
if w.gw == nil &&
|
||||
code != http.StatusNotModified && code != http.StatusNoContent {
|
||||
w.init()
|
||||
}
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// init graps a new gzip writer from the gzipWriterPool and writes the correct
|
||||
// content encoding header.
|
||||
func (w *GzipResponseWriter) init() {
|
||||
// Bytes written during ServeHTTP are redirected to this gzip writer
|
||||
// before being written to the underlying response.
|
||||
gzw := gzipWriterPools[w.index].Get().(*gzip.Writer)
|
||||
gzw.Reset(w.ResponseWriter)
|
||||
w.gw = gzw
|
||||
w.ResponseWriter.Header().Set(contentEncoding, "gzip")
|
||||
// if the Content-Length is already set, then calls to Write on gzip
|
||||
// will fail to set the Content-Length header since its already set
|
||||
// See: https://github.com/golang/go/issues/14975
|
||||
w.ResponseWriter.Header().Del(contentLength)
|
||||
}
|
||||
|
||||
// Close will close the gzip.Writer and will put it back in the gzipWriterPool.
|
||||
func (w *GzipResponseWriter) Close() error {
|
||||
if w.gw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := w.gw.Close()
|
||||
gzipWriterPools[w.index].Put(w.gw)
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush flushes the underlying *gzip.Writer and then the underlying
|
||||
// http.ResponseWriter if it is an http.Flusher. This makes GzipResponseWriter
|
||||
// an http.Flusher.
|
||||
func (w *GzipResponseWriter) Flush() {
|
||||
if w.gw != nil {
|
||||
w.gw.Flush()
|
||||
}
|
||||
|
||||
if fw, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
fw.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker. If the underlying ResponseWriter is a
|
||||
// Hijacker, its Hijack method is returned. Otherwise an error is returned.
|
||||
func (w *GzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("http.Hijacker interface is not supported")
|
||||
}
|
||||
|
||||
// verify Hijacker interface implementation
|
||||
var _ http.Hijacker = &GzipResponseWriter{}
|
||||
|
||||
// MustNewGzipLevelHandler behaves just like NewGzipLevelHandler except that in
|
||||
// an error case it panics rather than returning an error.
|
||||
func MustNewGzipLevelHandler(level int) func(http.Handler) http.Handler {
|
||||
wrap, err := NewGzipLevelHandler(level)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return wrap
|
||||
}
|
||||
|
||||
// NewGzipLevelHandler returns a wrapper function (often known as middleware)
|
||||
// which can be used to wrap an HTTP handler to transparently gzip the response
|
||||
// body if the client supports it (via the Accept-Encoding header). Responses will
|
||||
// be encoded at the given gzip compression level. An error will be returned only
|
||||
// if an invalid gzip compression level is given, so if one can ensure the level
|
||||
// is valid, the returned error can be safely ignored.
|
||||
func NewGzipLevelHandler(level int) (func(http.Handler) http.Handler, error) {
|
||||
if level != gzip.DefaultCompression && (level < gzip.BestSpeed || level > gzip.BestCompression) {
|
||||
return nil, fmt.Errorf("invalid compression level requested: %d", level)
|
||||
}
|
||||
return func(h http.Handler) http.Handler {
|
||||
index := poolIndex(level)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add(vary, acceptEncoding)
|
||||
|
||||
if acceptsGzip(r) {
|
||||
gw := &GzipResponseWriter{
|
||||
ResponseWriter: w,
|
||||
index: index,
|
||||
}
|
||||
defer gw.Close()
|
||||
|
||||
h.ServeHTTP(gw, r)
|
||||
} else {
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GzipHandler wraps an HTTP handler, to transparently gzip the response body if
|
||||
// the client supports it (via the Accept-Encoding header). This will compress at
|
||||
// the default compression level.
|
||||
func GzipHandler(h http.Handler) http.Handler {
|
||||
wrapper, _ := NewGzipLevelHandler(gzip.DefaultCompression)
|
||||
return wrapper(h)
|
||||
}
|
||||
|
||||
// acceptsGzip returns true if the given HTTP request indicates that it will
|
||||
// accept a gzipped response.
|
||||
func acceptsGzip(r *http.Request) bool {
|
||||
acceptedEncodings, _ := parseEncodings(r.Header.Get(acceptEncoding))
|
||||
return acceptedEncodings["gzip"] > 0.0
|
||||
}
|
||||
|
||||
// parseEncodings attempts to parse a list of codings, per RFC 2616, as might
|
||||
// appear in an Accept-Encoding header. It returns a map of content-codings to
|
||||
// quality values, and an error containing the errors encountered. It's probably
|
||||
// safe to ignore those, because silently ignoring errors is how the internet
|
||||
// works.
|
||||
//
|
||||
// See: http://tools.ietf.org/html/rfc2616#section-14.3.
|
||||
func parseEncodings(s string) (codings, error) {
|
||||
c := make(codings)
|
||||
var e []string
|
||||
|
||||
for _, ss := range strings.Split(s, ",") {
|
||||
coding, qvalue, err := parseCoding(ss)
|
||||
|
||||
if err != nil {
|
||||
e = append(e, err.Error())
|
||||
} else {
|
||||
c[coding] = qvalue
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (adammck): Use a proper multi-error struct, so the individual errors
|
||||
// can be extracted if anyone cares.
|
||||
if len(e) > 0 {
|
||||
return c, fmt.Errorf("errors while parsing encodings: %s", strings.Join(e, ", "))
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// parseCoding parses a single conding (content-coding with an optional qvalue),
|
||||
// as might appear in an Accept-Encoding header. It attempts to forgive minor
|
||||
// formatting errors.
|
||||
func parseCoding(s string) (coding string, qvalue float64, err error) {
|
||||
for n, part := range strings.Split(s, ";") {
|
||||
part = strings.TrimSpace(part)
|
||||
qvalue = DEFAULT_QVALUE
|
||||
|
||||
if n == 0 {
|
||||
coding = strings.ToLower(part)
|
||||
} else if strings.HasPrefix(part, "q=") {
|
||||
qvalue, err = strconv.ParseFloat(strings.TrimPrefix(part, "q="), 64)
|
||||
|
||||
if qvalue < 0.0 {
|
||||
qvalue = 0.0
|
||||
} else if qvalue > 1.0 {
|
||||
qvalue = 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if coding == "" {
|
||||
err = fmt.Errorf("empty content-coding")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,276 @@
|
|||
package gziphandler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseEncodings(t *testing.T) {
|
||||
|
||||
examples := map[string]codings{
|
||||
|
||||
// Examples from RFC 2616
|
||||
"compress, gzip": {"compress": 1.0, "gzip": 1.0},
|
||||
"": {},
|
||||
"*": {"*": 1.0},
|
||||
"compress;q=0.5, gzip;q=1.0": {"compress": 0.5, "gzip": 1.0},
|
||||
"gzip;q=1.0, identity; q=0.5, *;q=0": {"gzip": 1.0, "identity": 0.5, "*": 0.0},
|
||||
|
||||
// More random stuff
|
||||
"AAA;q=1": {"aaa": 1.0},
|
||||
"BBB ; q = 2": {"bbb": 1.0},
|
||||
}
|
||||
|
||||
for eg, exp := range examples {
|
||||
act, _ := parseEncodings(eg)
|
||||
assert.Equal(t, exp, act)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipHandler(t *testing.T) {
|
||||
testBody := "aaabbbccc"
|
||||
|
||||
// This just exists to provide something for GzipHandler to wrap.
|
||||
handler := newTestHandler(testBody)
|
||||
|
||||
// requests without accept-encoding are passed along as-is
|
||||
|
||||
req1, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
resp1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp1, req1)
|
||||
res1 := resp1.Result()
|
||||
|
||||
assert.Equal(t, 200, res1.StatusCode)
|
||||
assert.Equal(t, "", res1.Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "Accept-Encoding", res1.Header.Get("Vary"))
|
||||
assert.Equal(t, testBody, resp1.Body.String())
|
||||
|
||||
// but requests with accept-encoding:gzip are compressed if possible
|
||||
|
||||
req2, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
req2.Header.Set("Accept-Encoding", "gzip")
|
||||
resp2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp2, req2)
|
||||
res2 := resp2.Result()
|
||||
|
||||
assert.Equal(t, 200, res2.StatusCode)
|
||||
assert.Equal(t, "gzip", res2.Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "Accept-Encoding", res2.Header.Get("Vary"))
|
||||
assert.Equal(t, gzipStrLevel(testBody, gzip.DefaultCompression), resp2.Body.Bytes())
|
||||
|
||||
// content-type header is correctly set based on uncompressed body
|
||||
|
||||
req3, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
req3.Header.Set("Accept-Encoding", "gzip")
|
||||
res3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res3, req3)
|
||||
|
||||
assert.Equal(t, http.DetectContentType([]byte(testBody)), res3.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
func TestNewGzipLevelHandler(t *testing.T) {
|
||||
testBody := "aaabbbccc"
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
io.WriteString(w, testBody)
|
||||
})
|
||||
|
||||
for lvl := gzip.BestSpeed; lvl <= gzip.BestCompression; lvl++ {
|
||||
wrapper, err := NewGzipLevelHandler(lvl)
|
||||
if !assert.Nil(t, err, "NewGzipLevleHandler returned error for level:", lvl) {
|
||||
continue
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
resp := httptest.NewRecorder()
|
||||
wrapper(handler).ServeHTTP(resp, req)
|
||||
res := resp.Result()
|
||||
|
||||
assert.Equal(t, 200, res.StatusCode)
|
||||
assert.Equal(t, "gzip", res.Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "Accept-Encoding", res.Header.Get("Vary"))
|
||||
assert.Equal(t, gzipStrLevel(testBody, lvl), resp.Body.Bytes())
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGzipLevelHandlerReturnsErrorForInvalidLevels(t *testing.T) {
|
||||
var err error
|
||||
_, err = NewGzipLevelHandler(-42)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
_, err = NewGzipLevelHandler(42)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestMustNewGzipLevelHandlerWillPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("panic was not called")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustNewGzipLevelHandler(-42)
|
||||
}
|
||||
|
||||
func TestGzipHandlerNoBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
statusCode int
|
||||
contentEncoding string
|
||||
bodyLen int
|
||||
}{
|
||||
// Body must be empty.
|
||||
{http.StatusNoContent, "", 0},
|
||||
{http.StatusNotModified, "", 0},
|
||||
// Body is going to get gzip'd no matter what.
|
||||
{http.StatusOK, "gzip", 23},
|
||||
}
|
||||
|
||||
for num, test := range tests {
|
||||
handler := GzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.statusCode)
|
||||
}))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
// TODO: in Go1.7 httptest.NewRequest was introduced this should be used
|
||||
// once 1.6 is not longer supported.
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/"},
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMinor: 1,
|
||||
RemoteAddr: "192.0.2.1:1234",
|
||||
Header: make(http.Header),
|
||||
}
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
body, err := ioutil.ReadAll(rec.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error reading response body: %v", err)
|
||||
}
|
||||
|
||||
header := rec.Header()
|
||||
assert.Equal(t, test.contentEncoding, header.Get("Content-Encoding"), fmt.Sprintf("for test iteration %d", num))
|
||||
assert.Equal(t, "Accept-Encoding", header.Get("Vary"), fmt.Sprintf("for test iteration %d", num))
|
||||
assert.Equal(t, test.bodyLen, len(body), fmt.Sprintf("for test iteration %d", num))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipHandlerContentLength(t *testing.T) {
|
||||
b := []byte("testtesttesttesttesttesttesttesttesttesttesttesttest")
|
||||
handler := GzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(b)))
|
||||
w.Write(b)
|
||||
}))
|
||||
// httptest.NewRecorder doesn't give you access to the Content-Length
|
||||
// header so instead, we create a server on a random port and make
|
||||
// a request to that instead
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating listen socket: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
srv := &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
go srv.Serve(ln)
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/", Scheme: "http", Host: ln.Addr().String()},
|
||||
Header: make(http.Header),
|
||||
Close: true,
|
||||
}
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error making http request: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error reading response body: %v", err)
|
||||
}
|
||||
|
||||
l, err := strconv.Atoi(res.Header.Get("Content-Length"))
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error parsing Content-Length: %v", err)
|
||||
}
|
||||
assert.Len(t, body, l)
|
||||
assert.Equal(t, "gzip", res.Header.Get("Content-Encoding"))
|
||||
assert.NotEqual(t, b, body)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func BenchmarkGzipHandler_S2k(b *testing.B) { benchmark(b, false, 2048) }
|
||||
func BenchmarkGzipHandler_S20k(b *testing.B) { benchmark(b, false, 20480) }
|
||||
func BenchmarkGzipHandler_S100k(b *testing.B) { benchmark(b, false, 102400) }
|
||||
func BenchmarkGzipHandler_P2k(b *testing.B) { benchmark(b, true, 2048) }
|
||||
func BenchmarkGzipHandler_P20k(b *testing.B) { benchmark(b, true, 20480) }
|
||||
func BenchmarkGzipHandler_P100k(b *testing.B) { benchmark(b, true, 102400) }
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func gzipStrLevel(s string, lvl int) []byte {
|
||||
var b bytes.Buffer
|
||||
w, _ := gzip.NewWriterLevel(&b, lvl)
|
||||
io.WriteString(w, s)
|
||||
w.Close()
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func benchmark(b *testing.B, parallel bool, size int) {
|
||||
bin, err := ioutil.ReadFile("testdata/benchmark.json")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
handler := newTestHandler(string(bin[:size]))
|
||||
|
||||
if parallel {
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
runBenchmark(b, req, handler)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
runBenchmark(b, req, handler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runBenchmark(b *testing.B, req *http.Request, handler http.Handler) {
|
||||
res := httptest.NewRecorder()
|
||||
handler.ServeHTTP(res, req)
|
||||
if code := res.Code; code != 200 {
|
||||
b.Fatalf("Expected 200 but got %d", code)
|
||||
} else if blen := res.Body.Len(); blen < 500 {
|
||||
b.Fatalf("Expected complete response body, but got %d bytes", blen)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestHandler(body string) http.Handler {
|
||||
return GzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, body)
|
||||
}))
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
logrus
|
|
@ -0,0 +1,10 @@
|
|||
language: go
|
||||
go:
|
||||
- 1.3
|
||||
- 1.4
|
||||
- 1.5
|
||||
- 1.6
|
||||
- tip
|
||||
install:
|
||||
- go get -t ./...
|
||||
script: GOMAXPROCS=4 GORACE="halt_on_error=1" go test -race -v ./...
|
|
@ -0,0 +1,66 @@
|
|||
# 0.10.0
|
||||
|
||||
* feature: Add a test hook (#180)
|
||||
* feature: `ParseLevel` is now case-insensitive (#326)
|
||||
* feature: `FieldLogger` interface that generalizes `Logger` and `Entry` (#308)
|
||||
* performance: avoid re-allocations on `WithFields` (#335)
|
||||
|
||||
# 0.9.0
|
||||
|
||||
* logrus/text_formatter: don't emit empty msg
|
||||
* logrus/hooks/airbrake: move out of main repository
|
||||
* logrus/hooks/sentry: move out of main repository
|
||||
* logrus/hooks/papertrail: move out of main repository
|
||||
* logrus/hooks/bugsnag: move out of main repository
|
||||
* logrus/core: run tests with `-race`
|
||||
* logrus/core: detect TTY based on `stderr`
|
||||
* logrus/core: support `WithError` on logger
|
||||
* logrus/core: Solaris support
|
||||
|
||||
# 0.8.7
|
||||
|
||||
* logrus/core: fix possible race (#216)
|
||||
* logrus/doc: small typo fixes and doc improvements
|
||||
|
||||
|
||||
# 0.8.6
|
||||
|
||||
* hooks/raven: allow passing an initialized client
|
||||
|
||||
# 0.8.5
|
||||
|
||||
* logrus/core: revert #208
|
||||
|
||||
# 0.8.4
|
||||
|
||||
* formatter/text: fix data race (#218)
|
||||
|
||||
# 0.8.3
|
||||
|
||||
* logrus/core: fix entry log level (#208)
|
||||
* logrus/core: improve performance of text formatter by 40%
|
||||
* logrus/core: expose `LevelHooks` type
|
||||
* logrus/core: add support for DragonflyBSD and NetBSD
|
||||
* formatter/text: print structs more verbosely
|
||||
|
||||
# 0.8.2
|
||||
|
||||
* logrus: fix more Fatal family functions
|
||||
|
||||
# 0.8.1
|
||||
|
||||
* logrus: fix not exiting on `Fatalf` and `Fatalln`
|
||||
|
||||
# 0.8.0
|
||||
|
||||
* logrus: defaults to stderr instead of stdout
|
||||
* hooks/sentry: add special field for `*http.Request`
|
||||
* formatter/text: ignore Windows for colors
|
||||
|
||||
# 0.7.3
|
||||
|
||||
* formatter/\*: allow configuration of timestamp layout
|
||||
|
||||
# 0.7.2
|
||||
|
||||
* formatter/text: Add configuration option for time format (#158)
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Simon Eskildsen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,421 @@
|
|||
# Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:"/> [![Build Status](https://travis-ci.org/Sirupsen/logrus.svg?branch=master)](https://travis-ci.org/Sirupsen/logrus) [![GoDoc](https://godoc.org/github.com/Sirupsen/logrus?status.svg)](https://godoc.org/github.com/Sirupsen/logrus)
|
||||
|
||||
Logrus is a structured logger for Go (golang), completely API compatible with
|
||||
the standard library logger. [Godoc][godoc]. **Please note the Logrus API is not
|
||||
yet stable (pre 1.0). Logrus itself is completely stable and has been used in
|
||||
many large deployments. The core API is unlikely to change much but please
|
||||
version control your Logrus to make sure you aren't fetching latest `master` on
|
||||
every build.**
|
||||
|
||||
Nicely color-coded in development (when a TTY is attached, otherwise just
|
||||
plain text):
|
||||
|
||||
![Colored](http://i.imgur.com/PY7qMwd.png)
|
||||
|
||||
With `log.SetFormatter(&log.JSONFormatter{})`, for easy parsing by logstash
|
||||
or Splunk:
|
||||
|
||||
```json
|
||||
{"animal":"walrus","level":"info","msg":"A group of walrus emerges from the
|
||||
ocean","size":10,"time":"2014-03-10 19:57:38.562264131 -0400 EDT"}
|
||||
|
||||
{"level":"warning","msg":"The group's number increased tremendously!",
|
||||
"number":122,"omg":true,"time":"2014-03-10 19:57:38.562471297 -0400 EDT"}
|
||||
|
||||
{"animal":"walrus","level":"info","msg":"A giant walrus appears!",
|
||||
"size":10,"time":"2014-03-10 19:57:38.562500591 -0400 EDT"}
|
||||
|
||||
{"animal":"walrus","level":"info","msg":"Tremendously sized cow enters the ocean.",
|
||||
"size":9,"time":"2014-03-10 19:57:38.562527896 -0400 EDT"}
|
||||
|
||||
{"level":"fatal","msg":"The ice breaks!","number":100,"omg":true,
|
||||
"time":"2014-03-10 19:57:38.562543128 -0400 EDT"}
|
||||
```
|
||||
|
||||
With the default `log.SetFormatter(&log.TextFormatter{})` when a TTY is not
|
||||
attached, the output is compatible with the
|
||||
[logfmt](http://godoc.org/github.com/kr/logfmt) format:
|
||||
|
||||
```text
|
||||
time="2015-03-26T01:27:38-04:00" level=debug msg="Started observing beach" animal=walrus number=8
|
||||
time="2015-03-26T01:27:38-04:00" level=info msg="A group of walrus emerges from the ocean" animal=walrus size=10
|
||||
time="2015-03-26T01:27:38-04:00" level=warning msg="The group's number increased tremendously!" number=122 omg=true
|
||||
time="2015-03-26T01:27:38-04:00" level=debug msg="Temperature changes" temperature=-4
|
||||
time="2015-03-26T01:27:38-04:00" level=panic msg="It's over 9000!" animal=orca size=9009
|
||||
time="2015-03-26T01:27:38-04:00" level=fatal msg="The ice breaks!" err=&{0x2082280c0 map[animal:orca size:9009] 2015-03-26 01:27:38.441574009 -0400 EDT panic It's over 9000!} number=100 omg=true
|
||||
exit status 1
|
||||
```
|
||||
|
||||
#### Example
|
||||
|
||||
The simplest way to use Logrus is simply the package-level exported logger:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.WithFields(log.Fields{
|
||||
"animal": "walrus",
|
||||
}).Info("A walrus appears")
|
||||
}
|
||||
```
|
||||
|
||||
Note that it's completely api-compatible with the stdlib logger, so you can
|
||||
replace your `log` imports everywhere with `log "github.com/Sirupsen/logrus"`
|
||||
and you'll now have the flexibility of Logrus. You can customize it all you
|
||||
want:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Log as JSON instead of the default ASCII formatter.
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
|
||||
// Output to stderr instead of stdout, could also be a file.
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
// Only log the warning severity or above.
|
||||
log.SetLevel(log.WarnLevel)
|
||||
}
|
||||
|
||||
func main() {
|
||||
log.WithFields(log.Fields{
|
||||
"animal": "walrus",
|
||||
"size": 10,
|
||||
}).Info("A group of walrus emerges from the ocean")
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"omg": true,
|
||||
"number": 122,
|
||||
}).Warn("The group's number increased tremendously!")
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"omg": true,
|
||||
"number": 100,
|
||||
}).Fatal("The ice breaks!")
|
||||
|
||||
// A common pattern is to re-use fields between logging statements by re-using
|
||||
// the logrus.Entry returned from WithFields()
|
||||
contextLogger := log.WithFields(log.Fields{
|
||||
"common": "this is a common field",
|
||||
"other": "I also should be logged always",
|
||||
})
|
||||
|
||||
contextLogger.Info("I'll be logged with common and other field")
|
||||
contextLogger.Info("Me too")
|
||||
}
|
||||
```
|
||||
|
||||
For more advanced usage such as logging to multiple locations from the same
|
||||
application, you can also create an instance of the `logrus` Logger:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Create a new instance of the logger. You can have any number of instances.
|
||||
var log = logrus.New()
|
||||
|
||||
func main() {
|
||||
// The API for setting attributes is a little different than the package level
|
||||
// exported logger. See Godoc.
|
||||
log.Out = os.Stderr
|
||||
|
||||
log.WithFields(logrus.Fields{
|
||||
"animal": "walrus",
|
||||
"size": 10,
|
||||
}).Info("A group of walrus emerges from the ocean")
|
||||
}
|
||||
```
|
||||
|
||||
#### Fields
|
||||
|
||||
Logrus encourages careful, structured logging though logging fields instead of
|
||||
long, unparseable error messages. For example, instead of: `log.Fatalf("Failed
|
||||
to send event %s to topic %s with key %d")`, you should log the much more
|
||||
discoverable:
|
||||
|
||||
```go
|
||||
log.WithFields(log.Fields{
|
||||
"event": event,
|
||||
"topic": topic,
|
||||
"key": key,
|
||||
}).Fatal("Failed to send event")
|
||||
```
|
||||
|
||||
We've found this API forces you to think about logging in a way that produces
|
||||
much more useful logging messages. We've been in countless situations where just
|
||||
a single added field to a log statement that was already there would've saved us
|
||||
hours. The `WithFields` call is optional.
|
||||
|
||||
In general, with Logrus using any of the `printf`-family functions should be
|
||||
seen as a hint you should add a field, however, you can still use the
|
||||
`printf`-family functions with Logrus.
|
||||
|
||||
#### Hooks
|
||||
|
||||
You can add hooks for logging levels. For example to send errors to an exception
|
||||
tracking service on `Error`, `Fatal` and `Panic`, info to StatsD or log to
|
||||
multiple places simultaneously, e.g. syslog.
|
||||
|
||||
Logrus comes with [built-in hooks](hooks/). Add those, or your custom hook, in
|
||||
`init`:
|
||||
|
||||
```go
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"gopkg.in/gemnasium/logrus-airbrake-hook.v2" // the package is named "aibrake"
|
||||
logrus_syslog "github.com/Sirupsen/logrus/hooks/syslog"
|
||||
"log/syslog"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
// Use the Airbrake hook to report errors that have Error severity or above to
|
||||
// an exception tracker. You can create custom hooks, see the Hooks section.
|
||||
log.AddHook(airbrake.NewHook(123, "xyz", "production"))
|
||||
|
||||
hook, err := logrus_syslog.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "")
|
||||
if err != nil {
|
||||
log.Error("Unable to connect to local syslog daemon")
|
||||
} else {
|
||||
log.AddHook(hook)
|
||||
}
|
||||
}
|
||||
```
|
||||
Note: Syslog hook also support connecting to local syslog (Ex. "/dev/log" or "/var/run/syslog" or "/var/run/log"). For the detail, please check the [syslog hook README](hooks/syslog/README.md).
|
||||
|
||||
| Hook | Description |
|
||||
| ----- | ----------- |
|
||||
| [Airbrake](https://github.com/gemnasium/logrus-airbrake-hook) | Send errors to the Airbrake API V3. Uses the official [`gobrake`](https://github.com/airbrake/gobrake) behind the scenes. |
|
||||
| [Airbrake "legacy"](https://github.com/gemnasium/logrus-airbrake-legacy-hook) | Send errors to an exception tracking service compatible with the Airbrake API V2. Uses [`airbrake-go`](https://github.com/tobi/airbrake-go) behind the scenes. |
|
||||
| [Papertrail](https://github.com/polds/logrus-papertrail-hook) | Send errors to the [Papertrail](https://papertrailapp.com) hosted logging service via UDP. |
|
||||
| [Syslog](https://github.com/Sirupsen/logrus/blob/master/hooks/syslog/syslog.go) | Send errors to remote syslog server. Uses standard library `log/syslog` behind the scenes. |
|
||||
| [Bugsnag](https://github.com/Shopify/logrus-bugsnag/blob/master/bugsnag.go) | Send errors to the Bugsnag exception tracking service. |
|
||||
| [Sentry](https://github.com/evalphobia/logrus_sentry) | Send errors to the Sentry error logging and aggregation service. |
|
||||
| [Hiprus](https://github.com/nubo/hiprus) | Send errors to a channel in hipchat. |
|
||||
| [Logrusly](https://github.com/sebest/logrusly) | Send logs to [Loggly](https://www.loggly.com/) |
|
||||
| [Slackrus](https://github.com/johntdyer/slackrus) | Hook for Slack chat. |
|
||||
| [Journalhook](https://github.com/wercker/journalhook) | Hook for logging to `systemd-journald` |
|
||||
| [Graylog](https://github.com/gemnasium/logrus-graylog-hook) | Hook for logging to [Graylog](http://graylog2.org/) |
|
||||
| [Raygun](https://github.com/squirkle/logrus-raygun-hook) | Hook for logging to [Raygun.io](http://raygun.io/) |
|
||||
| [LFShook](https://github.com/rifflock/lfshook) | Hook for logging to the local filesystem |
|
||||
| [Honeybadger](https://github.com/agonzalezro/logrus_honeybadger) | Hook for sending exceptions to Honeybadger |
|
||||
| [Mail](https://github.com/zbindenren/logrus_mail) | Hook for sending exceptions via mail |
|
||||
| [Rollrus](https://github.com/heroku/rollrus) | Hook for sending errors to rollbar |
|
||||
| [Fluentd](https://github.com/evalphobia/logrus_fluent) | Hook for logging to fluentd |
|
||||
| [Mongodb](https://github.com/weekface/mgorus) | Hook for logging to mongodb |
|
||||
| [Influxus] (http://github.com/vlad-doru/influxus) | Hook for concurrently logging to [InfluxDB] (http://influxdata.com/) |
|
||||
| [InfluxDB](https://github.com/Abramovic/logrus_influxdb) | Hook for logging to influxdb |
|
||||
| [Octokit](https://github.com/dorajistyle/logrus-octokit-hook) | Hook for logging to github via octokit |
|
||||
| [DeferPanic](https://github.com/deferpanic/dp-logrus) | Hook for logging to DeferPanic |
|
||||
| [Redis-Hook](https://github.com/rogierlommers/logrus-redis-hook) | Hook for logging to a ELK stack (through Redis) |
|
||||
| [Amqp-Hook](https://github.com/vladoatanasov/logrus_amqp) | Hook for logging to Amqp broker (Like RabbitMQ) |
|
||||
| [KafkaLogrus](https://github.com/goibibo/KafkaLogrus) | Hook for logging to kafka |
|
||||
| [Typetalk](https://github.com/dragon3/logrus-typetalk-hook) | Hook for logging to [Typetalk](https://www.typetalk.in/) |
|
||||
| [ElasticSearch](https://github.com/sohlich/elogrus) | Hook for logging to ElasticSearch|
|
||||
| [Sumorus](https://github.com/doublefree/sumorus) | Hook for logging to [SumoLogic](https://www.sumologic.com/)|
|
||||
| [Logstash](https://github.com/bshuster-repo/logrus-logstash-hook) | Hook for logging to [Logstash](https://www.elastic.co/products/logstash) |
|
||||
| [Logmatic.io](https://github.com/logmatic/logmatic-go) | Hook for logging to [Logmatic.io](http://logmatic.io/) |
|
||||
|
||||
|
||||
#### Level logging
|
||||
|
||||
Logrus has six logging levels: Debug, Info, Warning, Error, Fatal and Panic.
|
||||
|
||||
```go
|
||||
log.Debug("Useful debugging information.")
|
||||
log.Info("Something noteworthy happened!")
|
||||
log.Warn("You should probably take a look at this.")
|
||||
log.Error("Something failed but I'm not quitting.")
|
||||
// Calls os.Exit(1) after logging
|
||||
log.Fatal("Bye.")
|
||||
// Calls panic() after logging
|
||||
log.Panic("I'm bailing.")
|
||||
```
|
||||
|
||||
You can set the logging level on a `Logger`, then it will only log entries with
|
||||
that severity or anything above it:
|
||||
|
||||
```go
|
||||
// Will log anything that is info or above (warn, error, fatal, panic). Default.
|
||||
log.SetLevel(log.InfoLevel)
|
||||
```
|
||||
|
||||
It may be useful to set `log.Level = logrus.DebugLevel` in a debug or verbose
|
||||
environment if your application has that.
|
||||
|
||||
#### Entries
|
||||
|
||||
Besides the fields added with `WithField` or `WithFields` some fields are
|
||||
automatically added to all logging events:
|
||||
|
||||
1. `time`. The timestamp when the entry was created.
|
||||
2. `msg`. The logging message passed to `{Info,Warn,Error,Fatal,Panic}` after
|
||||
the `AddFields` call. E.g. `Failed to send event.`
|
||||
3. `level`. The logging level. E.g. `info`.
|
||||
|
||||
#### Environments
|
||||
|
||||
Logrus has no notion of environment.
|
||||
|
||||
If you wish for hooks and formatters to only be used in specific environments,
|
||||
you should handle that yourself. For example, if your application has a global
|
||||
variable `Environment`, which is a string representation of the environment you
|
||||
could do:
|
||||
|
||||
```go
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
init() {
|
||||
// do something here to set environment depending on an environment variable
|
||||
// or command-line flag
|
||||
if Environment == "production" {
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
} else {
|
||||
// The TextFormatter is default, you don't actually have to do this.
|
||||
log.SetFormatter(&log.TextFormatter{})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This configuration is how `logrus` was intended to be used, but JSON in
|
||||
production is mostly only useful if you do log aggregation with tools like
|
||||
Splunk or Logstash.
|
||||
|
||||
#### Formatters
|
||||
|
||||
The built-in logging formatters are:
|
||||
|
||||
* `logrus.TextFormatter`. Logs the event in colors if stdout is a tty, otherwise
|
||||
without colors.
|
||||
* *Note:* to force colored output when there is no TTY, set the `ForceColors`
|
||||
field to `true`. To force no colored output even if there is a TTY set the
|
||||
`DisableColors` field to `true`
|
||||
* `logrus.JSONFormatter`. Logs fields as JSON.
|
||||
|
||||
Third party logging formatters:
|
||||
|
||||
* [`logstash`](https://github.com/bshuster-repo/logrus-logstash-hook). Logs fields as [Logstash](http://logstash.net) Events.
|
||||
* [`prefixed`](https://github.com/x-cray/logrus-prefixed-formatter). Displays log entry source along with alternative layout.
|
||||
* [`zalgo`](https://github.com/aybabtme/logzalgo). Invoking the P͉̫o̳̼̊w̖͈̰͎e̬͔̭͂r͚̼̹̲ ̫͓͉̳͈ō̠͕͖̚f̝͍̠ ͕̲̞͖͑Z̖̫̤̫ͪa͉̬͈̗l͖͎g̳̥o̰̥̅!̣͔̲̻͊̄ ̙̘̦̹̦.
|
||||
|
||||
You can define your formatter by implementing the `Formatter` interface,
|
||||
requiring a `Format` method. `Format` takes an `*Entry`. `entry.Data` is a
|
||||
`Fields` type (`map[string]interface{}`) with all your fields as well as the
|
||||
default ones (see Entries section above):
|
||||
|
||||
```go
|
||||
type MyJSONFormatter struct {
|
||||
}
|
||||
|
||||
log.SetFormatter(new(MyJSONFormatter))
|
||||
|
||||
func (f *MyJSONFormatter) Format(entry *Entry) ([]byte, error) {
|
||||
// Note this doesn't include Time, Level and Message which are available on
|
||||
// the Entry. Consult `godoc` on information about those fields or read the
|
||||
// source of the official loggers.
|
||||
serialized, err := json.Marshal(entry.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
|
||||
}
|
||||
return append(serialized, '\n'), nil
|
||||
}
|
||||
```
|
||||
|
||||
#### Logger as an `io.Writer`
|
||||
|
||||
Logrus can be transformed into an `io.Writer`. That writer is the end of an `io.Pipe` and it is your responsibility to close it.
|
||||
|
||||
```go
|
||||
w := logger.Writer()
|
||||
defer w.Close()
|
||||
|
||||
srv := http.Server{
|
||||
// create a stdlib log.Logger that writes to
|
||||
// logrus.Logger.
|
||||
ErrorLog: log.New(w, "", 0),
|
||||
}
|
||||
```
|
||||
|
||||
Each line written to that writer will be printed the usual way, using formatters
|
||||
and hooks. The level for those entries is `info`.
|
||||
|
||||
#### Rotation
|
||||
|
||||
Log rotation is not provided with Logrus. Log rotation should be done by an
|
||||
external program (like `logrotate(8)`) that can compress and delete old log
|
||||
entries. It should not be a feature of the application-level logger.
|
||||
|
||||
#### Tools
|
||||
|
||||
| Tool | Description |
|
||||
| ---- | ----------- |
|
||||
|[Logrus Mate](https://github.com/gogap/logrus_mate)|Logrus mate is a tool for Logrus to manage loggers, you can initial logger's level, hook and formatter by config file, the logger will generated with different config at different environment.|
|
||||
|
||||
#### Testing
|
||||
|
||||
Logrus has a built in facility for asserting the presence of log messages. This is implemented through the `test` hook and provides:
|
||||
|
||||
* decorators for existing logger (`test.NewLocal` and `test.NewGlobal`) which basically just add the `test` hook
|
||||
* a test logger (`test.NewNullLogger`) that just records log messages (and does not output any):
|
||||
|
||||
```go
|
||||
logger, hook := NewNullLogger()
|
||||
logger.Error("Hello error")
|
||||
|
||||
assert.Equal(1, len(hook.Entries))
|
||||
assert.Equal(logrus.ErrorLevel, hook.LastEntry().Level)
|
||||
assert.Equal("Hello error", hook.LastEntry().Message)
|
||||
|
||||
hook.Reset()
|
||||
assert.Nil(hook.LastEntry())
|
||||
```
|
||||
|
||||
#### Fatal handlers
|
||||
|
||||
Logrus can register one or more functions that will be called when any `fatal`
|
||||
level message is logged. The registered handlers will be executed before
|
||||
logrus performs a `os.Exit(1)`. This behavior may be helpful if callers need
|
||||
to gracefully shutdown. Unlike a `panic("Something went wrong...")` call which can be intercepted with a deferred `recover` a call to `os.Exit(1)` can not be intercepted.
|
||||
|
||||
```
|
||||
...
|
||||
handler := func() {
|
||||
// gracefully shutdown something...
|
||||
}
|
||||
logrus.RegisterExitHandler(handler)
|
||||
...
|
||||
```
|
||||
|
||||
#### Thread safty
|
||||
|
||||
By default Logger is protected by mutex for concurrent writes, this mutex is invoked when calling hooks and writing logs.
|
||||
If you are sure such locking is not needed, you can call logger.SetNoLock() to disable the locking.
|
||||
|
||||
Situation when locking is not needed includes:
|
||||
|
||||
* You have no hooks registered, or hooks calling is already thread-safe.
|
||||
|
||||
* Writing to logger.Out is already thread-safe, for example:
|
||||
|
||||
1) logger.Out is protected by locks.
|
||||
|
||||
2) logger.Out is a os.File handler opened with `O_APPEND` flag, and every write is smaller than 4k. (This allow multi-thread/multi-process writing)
|
||||
|
||||
(Refer to http://www.notthewizard.com/2014/06/17/are-files-appends-really-atomic/)
|
|
@ -0,0 +1,64 @@
|
|||
package logrus
|
||||
|
||||
// The following code was sourced and modified from the
|
||||
// https://bitbucket.org/tebeka/atexit package governed by the following license:
|
||||
//
|
||||
// Copyright (c) 2012 Miki Tebeka <miki.tebeka@gmail.com>.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var handlers = []func(){}
|
||||
|
||||
func runHandler(handler func()) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Error: Logrus exit handler error:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
handler()
|
||||
}
|
||||
|
||||
func runHandlers() {
|
||||
for _, handler := range handlers {
|
||||
runHandler(handler)
|
||||
}
|
||||
}
|
||||
|
||||
// Exit runs all the Logrus atexit handlers and then terminates the program using os.Exit(code)
|
||||
func Exit(code int) {
|
||||
runHandlers()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// RegisterExitHandler adds a Logrus Exit handler, call logrus.Exit to invoke
|
||||
// all handlers. The handlers will also be invoked when any Fatal log entry is
|
||||
// made.
|
||||
//
|
||||
// This method is useful when a caller wishes to use logrus to log a fatal
|
||||
// message but also needs to gracefully shutdown. An example usecase could be
|
||||
// closing database connections, or sending a alert that the application is
|
||||
// closing.
|
||||
func RegisterExitHandler(handler func()) {
|
||||
handlers = append(handlers, handler)
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
current := len(handlers)
|
||||
RegisterExitHandler(func() {})
|
||||
if len(handlers) != current+1 {
|
||||
t.Fatalf("can't add handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
gofile := "/tmp/testprog.go"
|
||||
if err := ioutil.WriteFile(gofile, testprog, 0666); err != nil {
|
||||
t.Fatalf("can't create go file")
|
||||
}
|
||||
|
||||
outfile := "/tmp/testprog.out"
|
||||
arg := time.Now().UTC().String()
|
||||
err := exec.Command("go", "run", gofile, outfile, arg).Run()
|
||||
if err == nil {
|
||||
t.Fatalf("completed normally, should have failed")
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(outfile)
|
||||
if err != nil {
|
||||
t.Fatalf("can't read output file %s", outfile)
|
||||
}
|
||||
|
||||
if string(data) != arg {
|
||||
t.Fatalf("bad data")
|
||||
}
|
||||
}
|
||||
|
||||
var testprog = []byte(`
|
||||
// Test program for atexit, gets output file and data as arguments and writes
|
||||
// data to output file in atexit handler.
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/Sirupsen/logrus"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
var outfile = ""
|
||||
var data = ""
|
||||
|
||||
func handler() {
|
||||
ioutil.WriteFile(outfile, []byte(data), 0666)
|
||||
}
|
||||
|
||||
func badHandler() {
|
||||
n := 0
|
||||
fmt.Println(1/n)
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
outfile = flag.Arg(0)
|
||||
data = flag.Arg(1)
|
||||
|
||||
logrus.RegisterExitHandler(handler)
|
||||
logrus.RegisterExitHandler(badHandler)
|
||||
logrus.Fatal("Bye bye")
|
||||
}
|
||||
`)
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
Package logrus is a structured logger for Go, completely API compatible with the standard library logger.
|
||||
|
||||
|
||||
The simplest way to use Logrus is simply the package-level exported logger:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.WithFields(log.Fields{
|
||||
"animal": "walrus",
|
||||
"number": 1,
|
||||
"size": 10,
|
||||
}).Info("A walrus appears")
|
||||
}
|
||||
|
||||
Output:
|
||||
time="2015-09-07T08:48:33Z" level=info msg="A walrus appears" animal=walrus number=1 size=10
|
||||
|
||||
For a full guide visit https://github.com/Sirupsen/logrus
|
||||
*/
|
||||
package logrus
|
|
@ -0,0 +1,275 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var bufferPool *sync.Pool
|
||||
|
||||
func init() {
|
||||
bufferPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Defines the key when adding errors using WithError.
|
||||
var ErrorKey = "error"
|
||||
|
||||
// An entry is the final or intermediate Logrus logging entry. It contains all
|
||||
// the fields passed with WithField{,s}. It's finally logged when Debug, Info,
|
||||
// Warn, Error, Fatal or Panic is called on it. These objects can be reused and
|
||||
// passed around as much as you wish to avoid field duplication.
|
||||
type Entry struct {
|
||||
Logger *Logger
|
||||
|
||||
// Contains all the fields set by the user.
|
||||
Data Fields
|
||||
|
||||
// Time at which the log entry was created
|
||||
Time time.Time
|
||||
|
||||
// Level the log entry was logged at: Debug, Info, Warn, Error, Fatal or Panic
|
||||
Level Level
|
||||
|
||||
// Message passed to Debug, Info, Warn, Error, Fatal or Panic
|
||||
Message string
|
||||
|
||||
// When formatter is called in entry.log(), an Buffer may be set to entry
|
||||
Buffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func NewEntry(logger *Logger) *Entry {
|
||||
return &Entry{
|
||||
Logger: logger,
|
||||
// Default is three fields, give a little extra room
|
||||
Data: make(Fields, 5),
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the string representation from the reader and ultimately the
|
||||
// formatter.
|
||||
func (entry *Entry) String() (string, error) {
|
||||
serialized, err := entry.Logger.Formatter.Format(entry)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
str := string(serialized)
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// Add an error as single field (using the key defined in ErrorKey) to the Entry.
|
||||
func (entry *Entry) WithError(err error) *Entry {
|
||||
return entry.WithField(ErrorKey, err)
|
||||
}
|
||||
|
||||
// Add a single field to the Entry.
|
||||
func (entry *Entry) WithField(key string, value interface{}) *Entry {
|
||||
return entry.WithFields(Fields{key: value})
|
||||
}
|
||||
|
||||
// Add a map of fields to the Entry.
|
||||
func (entry *Entry) WithFields(fields Fields) *Entry {
|
||||
data := make(Fields, len(entry.Data)+len(fields))
|
||||
for k, v := range entry.Data {
|
||||
data[k] = v
|
||||
}
|
||||
for k, v := range fields {
|
||||
data[k] = v
|
||||
}
|
||||
return &Entry{Logger: entry.Logger, Data: data}
|
||||
}
|
||||
|
||||
// This function is not declared with a pointer value because otherwise
|
||||
// race conditions will occur when using multiple goroutines
|
||||
func (entry Entry) log(level Level, msg string) {
|
||||
var buffer *bytes.Buffer
|
||||
entry.Time = time.Now()
|
||||
entry.Level = level
|
||||
entry.Message = msg
|
||||
|
||||
if err := entry.Logger.Hooks.Fire(level, &entry); err != nil {
|
||||
entry.Logger.mu.Lock()
|
||||
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
|
||||
entry.Logger.mu.Unlock()
|
||||
}
|
||||
buffer = bufferPool.Get().(*bytes.Buffer)
|
||||
buffer.Reset()
|
||||
defer bufferPool.Put(buffer)
|
||||
entry.Buffer = buffer
|
||||
serialized, err := entry.Logger.Formatter.Format(&entry)
|
||||
entry.Buffer = nil
|
||||
if err != nil {
|
||||
entry.Logger.mu.Lock()
|
||||
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
|
||||
entry.Logger.mu.Unlock()
|
||||
} else {
|
||||
entry.Logger.mu.Lock()
|
||||
_, err = entry.Logger.Out.Write(serialized)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
}
|
||||
entry.Logger.mu.Unlock()
|
||||
}
|
||||
|
||||
// To avoid Entry#log() returning a value that only would make sense for
|
||||
// panic() to use in Entry#Panic(), we avoid the allocation by checking
|
||||
// directly here.
|
||||
if level <= PanicLevel {
|
||||
panic(&entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Debug(args ...interface{}) {
|
||||
if entry.Logger.Level >= DebugLevel {
|
||||
entry.log(DebugLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Print(args ...interface{}) {
|
||||
entry.Info(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Info(args ...interface{}) {
|
||||
if entry.Logger.Level >= InfoLevel {
|
||||
entry.log(InfoLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warn(args ...interface{}) {
|
||||
if entry.Logger.Level >= WarnLevel {
|
||||
entry.log(WarnLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warning(args ...interface{}) {
|
||||
entry.Warn(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Error(args ...interface{}) {
|
||||
if entry.Logger.Level >= ErrorLevel {
|
||||
entry.log(ErrorLevel, fmt.Sprint(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Fatal(args ...interface{}) {
|
||||
if entry.Logger.Level >= FatalLevel {
|
||||
entry.log(FatalLevel, fmt.Sprint(args...))
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (entry *Entry) Panic(args ...interface{}) {
|
||||
if entry.Logger.Level >= PanicLevel {
|
||||
entry.log(PanicLevel, fmt.Sprint(args...))
|
||||
}
|
||||
panic(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
// Entry Printf family functions
|
||||
|
||||
func (entry *Entry) Debugf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= DebugLevel {
|
||||
entry.Debug(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Infof(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= InfoLevel {
|
||||
entry.Info(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Printf(format string, args ...interface{}) {
|
||||
entry.Infof(format, args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Warnf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= WarnLevel {
|
||||
entry.Warn(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warningf(format string, args ...interface{}) {
|
||||
entry.Warnf(format, args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Errorf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= ErrorLevel {
|
||||
entry.Error(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Fatalf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= FatalLevel {
|
||||
entry.Fatal(fmt.Sprintf(format, args...))
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (entry *Entry) Panicf(format string, args ...interface{}) {
|
||||
if entry.Logger.Level >= PanicLevel {
|
||||
entry.Panic(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// Entry Println family functions
|
||||
|
||||
func (entry *Entry) Debugln(args ...interface{}) {
|
||||
if entry.Logger.Level >= DebugLevel {
|
||||
entry.Debug(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Infoln(args ...interface{}) {
|
||||
if entry.Logger.Level >= InfoLevel {
|
||||
entry.Info(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Println(args ...interface{}) {
|
||||
entry.Infoln(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Warnln(args ...interface{}) {
|
||||
if entry.Logger.Level >= WarnLevel {
|
||||
entry.Warn(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Warningln(args ...interface{}) {
|
||||
entry.Warnln(args...)
|
||||
}
|
||||
|
||||
func (entry *Entry) Errorln(args ...interface{}) {
|
||||
if entry.Logger.Level >= ErrorLevel {
|
||||
entry.Error(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (entry *Entry) Fatalln(args ...interface{}) {
|
||||
if entry.Logger.Level >= FatalLevel {
|
||||
entry.Fatal(entry.sprintlnn(args...))
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (entry *Entry) Panicln(args ...interface{}) {
|
||||
if entry.Logger.Level >= PanicLevel {
|
||||
entry.Panic(entry.sprintlnn(args...))
|
||||
}
|
||||
}
|
||||
|
||||
// Sprintlnn => Sprint no newline. This is to get the behavior of how
|
||||
// fmt.Sprintln where spaces are always added between operands, regardless of
|
||||
// their type. Instead of vendoring the Sprintln implementation to spare a
|
||||
// string allocation, we do the simplest thing.
|
||||
func (entry *Entry) sprintlnn(args ...interface{}) string {
|
||||
msg := fmt.Sprintln(args...)
|
||||
return msg[:len(msg)-1]
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEntryWithError(t *testing.T) {
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
defer func() {
|
||||
ErrorKey = "error"
|
||||
}()
|
||||
|
||||
err := fmt.Errorf("kaboom at layer %d", 4711)
|
||||
|
||||
assert.Equal(err, WithError(err).Data["error"])
|
||||
|
||||
logger := New()
|
||||
logger.Out = &bytes.Buffer{}
|
||||
entry := NewEntry(logger)
|
||||
|
||||
assert.Equal(err, entry.WithError(err).Data["error"])
|
||||
|
||||
ErrorKey = "err"
|
||||
|
||||
assert.Equal(err, entry.WithError(err).Data["err"])
|
||||
|
||||
}
|
||||
|
||||
func TestEntryPanicln(t *testing.T) {
|
||||
errBoom := fmt.Errorf("boom time")
|
||||
|
||||
defer func() {
|
||||
p := recover()
|
||||
assert.NotNil(t, p)
|
||||
|
||||
switch pVal := p.(type) {
|
||||
case *Entry:
|
||||
assert.Equal(t, "kaboom", pVal.Message)
|
||||
assert.Equal(t, errBoom, pVal.Data["err"])
|
||||
default:
|
||||
t.Fatalf("want type *Entry, got %T: %#v", pVal, pVal)
|
||||
}
|
||||
}()
|
||||
|
||||
logger := New()
|
||||
logger.Out = &bytes.Buffer{}
|
||||
entry := NewEntry(logger)
|
||||
entry.WithField("err", errBoom).Panicln("kaboom")
|
||||
}
|
||||
|
||||
func TestEntryPanicf(t *testing.T) {
|
||||
errBoom := fmt.Errorf("boom again")
|
||||
|
||||
defer func() {
|
||||
p := recover()
|
||||
assert.NotNil(t, p)
|
||||
|
||||
switch pVal := p.(type) {
|
||||
case *Entry:
|
||||
assert.Equal(t, "kaboom true", pVal.Message)
|
||||
assert.Equal(t, errBoom, pVal.Data["err"])
|
||||
default:
|
||||
t.Fatalf("want type *Entry, got %T: %#v", pVal, pVal)
|
||||
}
|
||||
}()
|
||||
|
||||
logger := New()
|
||||
logger.Out = &bytes.Buffer{}
|
||||
entry := NewEntry(logger)
|
||||
entry.WithField("err", errBoom).Panicf("kaboom %v", true)
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
// std is the name of the standard logger in stdlib `log`
|
||||
std = New()
|
||||
)
|
||||
|
||||
func StandardLogger() *Logger {
|
||||
return std
|
||||
}
|
||||
|
||||
// SetOutput sets the standard logger output.
|
||||
func SetOutput(out io.Writer) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Out = out
|
||||
}
|
||||
|
||||
// SetFormatter sets the standard logger formatter.
|
||||
func SetFormatter(formatter Formatter) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Formatter = formatter
|
||||
}
|
||||
|
||||
// SetLevel sets the standard logger level.
|
||||
func SetLevel(level Level) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Level = level
|
||||
}
|
||||
|
||||
// GetLevel returns the standard logger level.
|
||||
func GetLevel() Level {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
return std.Level
|
||||
}
|
||||
|
||||
// AddHook adds a hook to the standard logger hooks.
|
||||
func AddHook(hook Hook) {
|
||||
std.mu.Lock()
|
||||
defer std.mu.Unlock()
|
||||
std.Hooks.Add(hook)
|
||||
}
|
||||
|
||||
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
|
||||
func WithError(err error) *Entry {
|
||||
return std.WithField(ErrorKey, err)
|
||||
}
|
||||
|
||||
// WithField creates an entry from the standard logger and adds a field to
|
||||
// it. If you want multiple fields, use `WithFields`.
|
||||
//
|
||||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
||||
// or Panic on the Entry it returns.
|
||||
func WithField(key string, value interface{}) *Entry {
|
||||
return std.WithField(key, value)
|
||||
}
|
||||
|
||||
// WithFields creates an entry from the standard logger and adds multiple
|
||||
// fields to it. This is simply a helper for `WithField`, invoking it
|
||||
// once for each field.
|
||||
//
|
||||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
||||
// or Panic on the Entry it returns.
|
||||
func WithFields(fields Fields) *Entry {
|
||||
return std.WithFields(fields)
|
||||
}
|
||||
|
||||
// Debug logs a message at level Debug on the standard logger.
|
||||
func Debug(args ...interface{}) {
|
||||
std.Debug(args...)
|
||||
}
|
||||
|
||||
// Print logs a message at level Info on the standard logger.
|
||||
func Print(args ...interface{}) {
|
||||
std.Print(args...)
|
||||
}
|
||||
|
||||
// Info logs a message at level Info on the standard logger.
|
||||
func Info(args ...interface{}) {
|
||||
std.Info(args...)
|
||||
}
|
||||
|
||||
// Warn logs a message at level Warn on the standard logger.
|
||||
func Warn(args ...interface{}) {
|
||||
std.Warn(args...)
|
||||
}
|
||||
|
||||
// Warning logs a message at level Warn on the standard logger.
|
||||
func Warning(args ...interface{}) {
|
||||
std.Warning(args...)
|
||||
}
|
||||
|
||||
// Error logs a message at level Error on the standard logger.
|
||||
func Error(args ...interface{}) {
|
||||
std.Error(args...)
|
||||
}
|
||||
|
||||
// Panic logs a message at level Panic on the standard logger.
|
||||
func Panic(args ...interface{}) {
|
||||
std.Panic(args...)
|
||||
}
|
||||
|
||||
// Fatal logs a message at level Fatal on the standard logger.
|
||||
func Fatal(args ...interface{}) {
|
||||
std.Fatal(args...)
|
||||
}
|
||||
|
||||
// Debugf logs a message at level Debug on the standard logger.
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
std.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Printf logs a message at level Info on the standard logger.
|
||||
func Printf(format string, args ...interface{}) {
|
||||
std.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs a message at level Info on the standard logger.
|
||||
func Infof(format string, args ...interface{}) {
|
||||
std.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Warnf logs a message at level Warn on the standard logger.
|
||||
func Warnf(format string, args ...interface{}) {
|
||||
std.Warnf(format, args...)
|
||||
}
|
||||
|
||||
// Warningf logs a message at level Warn on the standard logger.
|
||||
func Warningf(format string, args ...interface{}) {
|
||||
std.Warningf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs a message at level Error on the standard logger.
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
std.Errorf(format, args...)
|
||||
}
|
||||
|
||||
// Panicf logs a message at level Panic on the standard logger.
|
||||
func Panicf(format string, args ...interface{}) {
|
||||
std.Panicf(format, args...)
|
||||
}
|
||||
|
||||
// Fatalf logs a message at level Fatal on the standard logger.
|
||||
func Fatalf(format string, args ...interface{}) {
|
||||
std.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// Debugln logs a message at level Debug on the standard logger.
|
||||
func Debugln(args ...interface{}) {
|
||||
std.Debugln(args...)
|
||||
}
|
||||
|
||||
// Println logs a message at level Info on the standard logger.
|
||||
func Println(args ...interface{}) {
|
||||
std.Println(args...)
|
||||
}
|
||||
|
||||
// Infoln logs a message at level Info on the standard logger.
|
||||
func Infoln(args ...interface{}) {
|
||||
std.Infoln(args...)
|
||||
}
|
||||
|
||||
// Warnln logs a message at level Warn on the standard logger.
|
||||
func Warnln(args ...interface{}) {
|
||||
std.Warnln(args...)
|
||||
}
|
||||
|
||||
// Warningln logs a message at level Warn on the standard logger.
|
||||
func Warningln(args ...interface{}) {
|
||||
std.Warningln(args...)
|
||||
}
|
||||
|
||||
// Errorln logs a message at level Error on the standard logger.
|
||||
func Errorln(args ...interface{}) {
|
||||
std.Errorln(args...)
|
||||
}
|
||||
|
||||
// Panicln logs a message at level Panic on the standard logger.
|
||||
func Panicln(args ...interface{}) {
|
||||
std.Panicln(args...)
|
||||
}
|
||||
|
||||
// Fatalln logs a message at level Fatal on the standard logger.
|
||||
func Fatalln(args ...interface{}) {
|
||||
std.Fatalln(args...)
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package logrus
|
||||
|
||||
import "time"
|
||||
|
||||
const DefaultTimestampFormat = time.RFC3339
|
||||
|
||||
// The Formatter interface is used to implement a custom Formatter. It takes an
|
||||
// `Entry`. It exposes all the fields, including the default ones:
|
||||
//
|
||||
// * `entry.Data["msg"]`. The message passed from Info, Warn, Error ..
|
||||
// * `entry.Data["time"]`. The timestamp.
|
||||
// * `entry.Data["level"]. The level the entry was logged at.
|
||||
//
|
||||
// Any additional fields added with `WithField` or `WithFields` are also in
|
||||
// `entry.Data`. Format is expected to return an array of bytes which are then
|
||||
// logged to `logger.Out`.
|
||||
type Formatter interface {
|
||||
Format(*Entry) ([]byte, error)
|
||||
}
|
||||
|
||||
// This is to not silently overwrite `time`, `msg` and `level` fields when
|
||||
// dumping it. If this code wasn't there doing:
|
||||
//
|
||||
// logrus.WithField("level", 1).Info("hello")
|
||||
//
|
||||
// Would just silently drop the user provided level. Instead with this code
|
||||
// it'll logged as:
|
||||
//
|
||||
// {"level": "info", "fields.level": 1, "msg": "hello", "time": "..."}
|
||||
//
|
||||
// It's not exported because it's still using Data in an opinionated way. It's to
|
||||
// avoid code duplication between the two default formatters.
|
||||
func prefixFieldClashes(data Fields) {
|
||||
if t, ok := data["time"]; ok {
|
||||
data["fields.time"] = t
|
||||
}
|
||||
|
||||
if m, ok := data["msg"]; ok {
|
||||
data["fields.msg"] = m
|
||||
}
|
||||
|
||||
if l, ok := data["level"]; ok {
|
||||
data["fields.level"] = l
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// smallFields is a small size data set for benchmarking
|
||||
var smallFields = Fields{
|
||||
"foo": "bar",
|
||||
"baz": "qux",
|
||||
"one": "two",
|
||||
"three": "four",
|
||||
}
|
||||
|
||||
// largeFields is a large size data set for benchmarking
|
||||
var largeFields = Fields{
|
||||
"foo": "bar",
|
||||
"baz": "qux",
|
||||
"one": "two",
|
||||
"three": "four",
|
||||
"five": "six",
|
||||
"seven": "eight",
|
||||
"nine": "ten",
|
||||
"eleven": "twelve",
|
||||
"thirteen": "fourteen",
|
||||
"fifteen": "sixteen",
|
||||
"seventeen": "eighteen",
|
||||
"nineteen": "twenty",
|
||||
"a": "b",
|
||||
"c": "d",
|
||||
"e": "f",
|
||||
"g": "h",
|
||||
"i": "j",
|
||||
"k": "l",
|
||||
"m": "n",
|
||||
"o": "p",
|
||||
"q": "r",
|
||||
"s": "t",
|
||||
"u": "v",
|
||||
"w": "x",
|
||||
"y": "z",
|
||||
"this": "will",
|
||||
"make": "thirty",
|
||||
"entries": "yeah",
|
||||
}
|
||||
|
||||
var errorFields = Fields{
|
||||
"foo": fmt.Errorf("bar"),
|
||||
"baz": fmt.Errorf("qux"),
|
||||
}
|
||||
|
||||
func BenchmarkErrorTextFormatter(b *testing.B) {
|
||||
doBenchmark(b, &TextFormatter{DisableColors: true}, errorFields)
|
||||
}
|
||||
|
||||
func BenchmarkSmallTextFormatter(b *testing.B) {
|
||||
doBenchmark(b, &TextFormatter{DisableColors: true}, smallFields)
|
||||
}
|
||||
|
||||
func BenchmarkLargeTextFormatter(b *testing.B) {
|
||||
doBenchmark(b, &TextFormatter{DisableColors: true}, largeFields)
|
||||
}
|
||||
|
||||
func BenchmarkSmallColoredTextFormatter(b *testing.B) {
|
||||
doBenchmark(b, &TextFormatter{ForceColors: true}, smallFields)
|
||||
}
|
||||
|
||||
func BenchmarkLargeColoredTextFormatter(b *testing.B) {
|
||||
doBenchmark(b, &TextFormatter{ForceColors: true}, largeFields)
|
||||
}
|
||||
|
||||
func BenchmarkSmallJSONFormatter(b *testing.B) {
|
||||
doBenchmark(b, &JSONFormatter{}, smallFields)
|
||||
}
|
||||
|
||||
func BenchmarkLargeJSONFormatter(b *testing.B) {
|
||||
doBenchmark(b, &JSONFormatter{}, largeFields)
|
||||
}
|
||||
|
||||
func doBenchmark(b *testing.B, formatter Formatter, fields Fields) {
|
||||
entry := &Entry{
|
||||
Time: time.Time{},
|
||||
Level: InfoLevel,
|
||||
Message: "message",
|
||||
Data: fields,
|
||||
}
|
||||
var d []byte
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
d, err = formatter.Format(entry)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.SetBytes(int64(len(d)))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestHook struct {
|
||||
Fired bool
|
||||
}
|
||||
|
||||
func (hook *TestHook) Fire(entry *Entry) error {
|
||||
hook.Fired = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hook *TestHook) Levels() []Level {
|
||||
return []Level{
|
||||
DebugLevel,
|
||||
InfoLevel,
|
||||
WarnLevel,
|
||||
ErrorLevel,
|
||||
FatalLevel,
|
||||
PanicLevel,
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookFires(t *testing.T) {
|
||||
hook := new(TestHook)
|
||||
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Hooks.Add(hook)
|
||||
assert.Equal(t, hook.Fired, false)
|
||||
|
||||
log.Print("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, hook.Fired, true)
|
||||
})
|
||||
}
|
||||
|
||||
type ModifyHook struct {
|
||||
}
|
||||
|
||||
func (hook *ModifyHook) Fire(entry *Entry) error {
|
||||
entry.Data["wow"] = "whale"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hook *ModifyHook) Levels() []Level {
|
||||
return []Level{
|
||||
DebugLevel,
|
||||
InfoLevel,
|
||||
WarnLevel,
|
||||
ErrorLevel,
|
||||
FatalLevel,
|
||||
PanicLevel,
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookCanModifyEntry(t *testing.T) {
|
||||
hook := new(ModifyHook)
|
||||
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Hooks.Add(hook)
|
||||
log.WithField("wow", "elephant").Print("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["wow"], "whale")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCanFireMultipleHooks(t *testing.T) {
|
||||
hook1 := new(ModifyHook)
|
||||
hook2 := new(TestHook)
|
||||
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Hooks.Add(hook1)
|
||||
log.Hooks.Add(hook2)
|
||||
|
||||
log.WithField("wow", "elephant").Print("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["wow"], "whale")
|
||||
assert.Equal(t, hook2.Fired, true)
|
||||
})
|
||||
}
|
||||
|
||||
type ErrorHook struct {
|
||||
Fired bool
|
||||
}
|
||||
|
||||
func (hook *ErrorHook) Fire(entry *Entry) error {
|
||||
hook.Fired = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hook *ErrorHook) Levels() []Level {
|
||||
return []Level{
|
||||
ErrorLevel,
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHookShouldntFireOnInfo(t *testing.T) {
|
||||
hook := new(ErrorHook)
|
||||
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Hooks.Add(hook)
|
||||
log.Info("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, hook.Fired, false)
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorHookShouldFireOnError(t *testing.T) {
|
||||
hook := new(ErrorHook)
|
||||
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Hooks.Add(hook)
|
||||
log.Error("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, hook.Fired, true)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package logrus
|
||||
|
||||
// A hook to be fired when logging on the logging levels returned from
|
||||
// `Levels()` on your implementation of the interface. Note that this is not
|
||||
// fired in a goroutine or a channel with workers, you should handle such
|
||||
// functionality yourself if your call is non-blocking and you don't wish for
|
||||
// the logging calls for levels returned from `Levels()` to block.
|
||||
type Hook interface {
|
||||
Levels() []Level
|
||||
Fire(*Entry) error
|
||||
}
|
||||
|
||||
// Internal type for storing the hooks on a logger instance.
|
||||
type LevelHooks map[Level][]Hook
|
||||
|
||||
// Add a hook to an instance of logger. This is called with
|
||||
// `log.Hooks.Add(new(MyHook))` where `MyHook` implements the `Hook` interface.
|
||||
func (hooks LevelHooks) Add(hook Hook) {
|
||||
for _, level := range hook.Levels() {
|
||||
hooks[level] = append(hooks[level], hook)
|
||||
}
|
||||
}
|
||||
|
||||
// Fire all the hooks for the passed level. Used by `entry.log` to fire
|
||||
// appropriate hooks for a log entry.
|
||||
func (hooks LevelHooks) Fire(level Level, entry *Entry) error {
|
||||
for _, hook := range hooks[level] {
|
||||
if err := hook.Fire(entry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type JSONFormatter struct {
|
||||
// TimestampFormat sets the format used for marshaling timestamps.
|
||||
TimestampFormat string
|
||||
}
|
||||
|
||||
func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
|
||||
data := make(Fields, len(entry.Data)+3)
|
||||
for k, v := range entry.Data {
|
||||
switch v := v.(type) {
|
||||
case error:
|
||||
// Otherwise errors are ignored by `encoding/json`
|
||||
// https://github.com/Sirupsen/logrus/issues/137
|
||||
data[k] = v.Error()
|
||||
default:
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
prefixFieldClashes(data)
|
||||
|
||||
timestampFormat := f.TimestampFormat
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = DefaultTimestampFormat
|
||||
}
|
||||
|
||||
data["time"] = entry.Time.Format(timestampFormat)
|
||||
data["msg"] = entry.Message
|
||||
data["level"] = entry.Level.String()
|
||||
|
||||
serialized, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
|
||||
}
|
||||
return append(serialized, '\n'), nil
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorNotLost(t *testing.T) {
|
||||
formatter := &JSONFormatter{}
|
||||
|
||||
b, err := formatter.Format(WithField("error", errors.New("wild walrus")))
|
||||
if err != nil {
|
||||
t.Fatal("Unable to format entry: ", err)
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
err = json.Unmarshal(b, &entry)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to unmarshal formatted entry: ", err)
|
||||
}
|
||||
|
||||
if entry["error"] != "wild walrus" {
|
||||
t.Fatal("Error field not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorNotLostOnFieldNotNamedError(t *testing.T) {
|
||||
formatter := &JSONFormatter{}
|
||||
|
||||
b, err := formatter.Format(WithField("omg", errors.New("wild walrus")))
|
||||
if err != nil {
|
||||
t.Fatal("Unable to format entry: ", err)
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
err = json.Unmarshal(b, &entry)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to unmarshal formatted entry: ", err)
|
||||
}
|
||||
|
||||
if entry["omg"] != "wild walrus" {
|
||||
t.Fatal("Error field not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldClashWithTime(t *testing.T) {
|
||||
formatter := &JSONFormatter{}
|
||||
|
||||
b, err := formatter.Format(WithField("time", "right now!"))
|
||||
if err != nil {
|
||||
t.Fatal("Unable to format entry: ", err)
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
err = json.Unmarshal(b, &entry)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to unmarshal formatted entry: ", err)
|
||||
}
|
||||
|
||||
if entry["fields.time"] != "right now!" {
|
||||
t.Fatal("fields.time not set to original time field")
|
||||
}
|
||||
|
||||
if entry["time"] != "0001-01-01T00:00:00Z" {
|
||||
t.Fatal("time field not set to current time, was: ", entry["time"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldClashWithMsg(t *testing.T) {
|
||||
formatter := &JSONFormatter{}
|
||||
|
||||
b, err := formatter.Format(WithField("msg", "something"))
|
||||
if err != nil {
|
||||
t.Fatal("Unable to format entry: ", err)
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
err = json.Unmarshal(b, &entry)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to unmarshal formatted entry: ", err)
|
||||
}
|
||||
|
||||
if entry["fields.msg"] != "something" {
|
||||
t.Fatal("fields.msg not set to original msg field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldClashWithLevel(t *testing.T) {
|
||||
formatter := &JSONFormatter{}
|
||||
|
||||
b, err := formatter.Format(WithField("level", "something"))
|
||||
if err != nil {
|
||||
t.Fatal("Unable to format entry: ", err)
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
err = json.Unmarshal(b, &entry)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to unmarshal formatted entry: ", err)
|
||||
}
|
||||
|
||||
if entry["fields.level"] != "something" {
|
||||
t.Fatal("fields.level not set to original level field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONEntryEndsWithNewline(t *testing.T) {
|
||||
formatter := &JSONFormatter{}
|
||||
|
||||
b, err := formatter.Format(WithField("level", "something"))
|
||||
if err != nil {
|
||||
t.Fatal("Unable to format entry: ", err)
|
||||
}
|
||||
|
||||
if b[len(b)-1] != '\n' {
|
||||
t.Fatal("Expected JSON log entry to end with a newline")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,308 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
// The logs are `io.Copy`'d to this in a mutex. It's common to set this to a
|
||||
// file, or leave it default which is `os.Stderr`. You can also set this to
|
||||
// something more adventorous, such as logging to Kafka.
|
||||
Out io.Writer
|
||||
// Hooks for the logger instance. These allow firing events based on logging
|
||||
// levels and log entries. For example, to send errors to an error tracking
|
||||
// service, log to StatsD or dump the core on fatal errors.
|
||||
Hooks LevelHooks
|
||||
// All log entries pass through the formatter before logged to Out. The
|
||||
// included formatters are `TextFormatter` and `JSONFormatter` for which
|
||||
// TextFormatter is the default. In development (when a TTY is attached) it
|
||||
// logs with colors, but to a file it wouldn't. You can easily implement your
|
||||
// own that implements the `Formatter` interface, see the `README` or included
|
||||
// formatters for examples.
|
||||
Formatter Formatter
|
||||
// The logging level the logger should log at. This is typically (and defaults
|
||||
// to) `logrus.Info`, which allows Info(), Warn(), Error() and Fatal() to be
|
||||
// logged. `logrus.Debug` is useful in
|
||||
Level Level
|
||||
// Used to sync writing to the log. Locking is enabled by Default
|
||||
mu MutexWrap
|
||||
// Reusable empty entry
|
||||
entryPool sync.Pool
|
||||
}
|
||||
|
||||
type MutexWrap struct {
|
||||
lock sync.Mutex
|
||||
disabled bool
|
||||
}
|
||||
|
||||
func (mw *MutexWrap) Lock() {
|
||||
if !mw.disabled {
|
||||
mw.lock.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (mw *MutexWrap) Unlock() {
|
||||
if !mw.disabled {
|
||||
mw.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (mw *MutexWrap) Disable() {
|
||||
mw.disabled = true
|
||||
}
|
||||
|
||||
// Creates a new logger. Configuration should be set by changing `Formatter`,
|
||||
// `Out` and `Hooks` directly on the default logger instance. You can also just
|
||||
// instantiate your own:
|
||||
//
|
||||
// var log = &Logger{
|
||||
// Out: os.Stderr,
|
||||
// Formatter: new(JSONFormatter),
|
||||
// Hooks: make(LevelHooks),
|
||||
// Level: logrus.DebugLevel,
|
||||
// }
|
||||
//
|
||||
// It's recommended to make this a global instance called `log`.
|
||||
func New() *Logger {
|
||||
return &Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: new(TextFormatter),
|
||||
Hooks: make(LevelHooks),
|
||||
Level: InfoLevel,
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) newEntry() *Entry {
|
||||
entry, ok := logger.entryPool.Get().(*Entry)
|
||||
if ok {
|
||||
return entry
|
||||
}
|
||||
return NewEntry(logger)
|
||||
}
|
||||
|
||||
func (logger *Logger) releaseEntry(entry *Entry) {
|
||||
logger.entryPool.Put(entry)
|
||||
}
|
||||
|
||||
// Adds a field to the log entry, note that it doesn't log until you call
|
||||
// Debug, Print, Info, Warn, Fatal or Panic. It only creates a log entry.
|
||||
// If you want multiple fields, use `WithFields`.
|
||||
func (logger *Logger) WithField(key string, value interface{}) *Entry {
|
||||
entry := logger.newEntry()
|
||||
defer logger.releaseEntry(entry)
|
||||
return entry.WithField(key, value)
|
||||
}
|
||||
|
||||
// Adds a struct of fields to the log entry. All it does is call `WithField` for
|
||||
// each `Field`.
|
||||
func (logger *Logger) WithFields(fields Fields) *Entry {
|
||||
entry := logger.newEntry()
|
||||
defer logger.releaseEntry(entry)
|
||||
return entry.WithFields(fields)
|
||||
}
|
||||
|
||||
// Add an error as single field to the log entry. All it does is call
|
||||
// `WithError` for the given `error`.
|
||||
func (logger *Logger) WithError(err error) *Entry {
|
||||
entry := logger.newEntry()
|
||||
defer logger.releaseEntry(entry)
|
||||
return entry.WithError(err)
|
||||
}
|
||||
|
||||
func (logger *Logger) Debugf(format string, args ...interface{}) {
|
||||
if logger.Level >= DebugLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Debugf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Infof(format string, args ...interface{}) {
|
||||
if logger.Level >= InfoLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Infof(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Printf(format string, args ...interface{}) {
|
||||
entry := logger.newEntry()
|
||||
entry.Printf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
|
||||
func (logger *Logger) Warnf(format string, args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Warningf(format string, args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Errorf(format string, args ...interface{}) {
|
||||
if logger.Level >= ErrorLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Errorf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Fatalf(format string, args ...interface{}) {
|
||||
if logger.Level >= FatalLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Fatalf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (logger *Logger) Panicf(format string, args ...interface{}) {
|
||||
if logger.Level >= PanicLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Panicf(format, args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Debug(args ...interface{}) {
|
||||
if logger.Level >= DebugLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Debug(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Info(args ...interface{}) {
|
||||
if logger.Level >= InfoLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Info(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Print(args ...interface{}) {
|
||||
entry := logger.newEntry()
|
||||
entry.Info(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
|
||||
func (logger *Logger) Warn(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warn(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Warning(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warn(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Error(args ...interface{}) {
|
||||
if logger.Level >= ErrorLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Error(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Fatal(args ...interface{}) {
|
||||
if logger.Level >= FatalLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Fatal(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (logger *Logger) Panic(args ...interface{}) {
|
||||
if logger.Level >= PanicLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Panic(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Debugln(args ...interface{}) {
|
||||
if logger.Level >= DebugLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Debugln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Infoln(args ...interface{}) {
|
||||
if logger.Level >= InfoLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Infoln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Println(args ...interface{}) {
|
||||
entry := logger.newEntry()
|
||||
entry.Println(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
|
||||
func (logger *Logger) Warnln(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Warningln(args ...interface{}) {
|
||||
if logger.Level >= WarnLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Warnln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Errorln(args ...interface{}) {
|
||||
if logger.Level >= ErrorLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Errorln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (logger *Logger) Fatalln(args ...interface{}) {
|
||||
if logger.Level >= FatalLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Fatalln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
Exit(1)
|
||||
}
|
||||
|
||||
func (logger *Logger) Panicln(args ...interface{}) {
|
||||
if logger.Level >= PanicLevel {
|
||||
entry := logger.newEntry()
|
||||
entry.Panicln(args...)
|
||||
logger.releaseEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
//When file is opened with appending mode, it's safe to
|
||||
//write concurrently to a file (within 4k message on Linux).
|
||||
//In these cases user can choose to disable the lock.
|
||||
func (logger *Logger) SetNoLock() {
|
||||
logger.mu.Disable()
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// smallFields is a small size data set for benchmarking
|
||||
var loggerFields = Fields{
|
||||
"foo": "bar",
|
||||
"baz": "qux",
|
||||
"one": "two",
|
||||
"three": "four",
|
||||
}
|
||||
|
||||
func BenchmarkDummyLogger(b *testing.B) {
|
||||
nullf, err := os.OpenFile("/dev/null", os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
b.Fatalf("%v", err)
|
||||
}
|
||||
defer nullf.Close()
|
||||
doLoggerBenchmark(b, nullf, &TextFormatter{DisableColors: true}, smallFields)
|
||||
}
|
||||
|
||||
func BenchmarkDummyLoggerNoLock(b *testing.B) {
|
||||
nullf, err := os.OpenFile("/dev/null", os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
b.Fatalf("%v", err)
|
||||
}
|
||||
defer nullf.Close()
|
||||
doLoggerBenchmarkNoLock(b, nullf, &TextFormatter{DisableColors: true}, smallFields)
|
||||
}
|
||||
|
||||
func doLoggerBenchmark(b *testing.B, out *os.File, formatter Formatter, fields Fields) {
|
||||
logger := Logger{
|
||||
Out: out,
|
||||
Level: InfoLevel,
|
||||
Formatter: formatter,
|
||||
}
|
||||
entry := logger.WithFields(fields)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
entry.Info("aaa")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func doLoggerBenchmarkNoLock(b *testing.B, out *os.File, formatter Formatter, fields Fields) {
|
||||
logger := Logger{
|
||||
Out: out,
|
||||
Level: InfoLevel,
|
||||
Formatter: formatter,
|
||||
}
|
||||
logger.SetNoLock()
|
||||
entry := logger.WithFields(fields)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
entry.Info("aaa")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Fields type, used to pass to `WithFields`.
|
||||
type Fields map[string]interface{}
|
||||
|
||||
// Level type
|
||||
type Level uint8
|
||||
|
||||
// Convert the Level to a string. E.g. PanicLevel becomes "panic".
|
||||
func (level Level) String() string {
|
||||
switch level {
|
||||
case DebugLevel:
|
||||
return "debug"
|
||||
case InfoLevel:
|
||||
return "info"
|
||||
case WarnLevel:
|
||||
return "warning"
|
||||
case ErrorLevel:
|
||||
return "error"
|
||||
case FatalLevel:
|
||||
return "fatal"
|
||||
case PanicLevel:
|
||||
return "panic"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// ParseLevel takes a string level and returns the Logrus log level constant.
|
||||
func ParseLevel(lvl string) (Level, error) {
|
||||
switch strings.ToLower(lvl) {
|
||||
case "panic":
|
||||
return PanicLevel, nil
|
||||
case "fatal":
|
||||
return FatalLevel, nil
|
||||
case "error":
|
||||
return ErrorLevel, nil
|
||||
case "warn", "warning":
|
||||
return WarnLevel, nil
|
||||
case "info":
|
||||
return InfoLevel, nil
|
||||
case "debug":
|
||||
return DebugLevel, nil
|
||||
}
|
||||
|
||||
var l Level
|
||||
return l, fmt.Errorf("not a valid logrus Level: %q", lvl)
|
||||
}
|
||||
|
||||
// A constant exposing all logging levels
|
||||
var AllLevels = []Level{
|
||||
PanicLevel,
|
||||
FatalLevel,
|
||||
ErrorLevel,
|
||||
WarnLevel,
|
||||
InfoLevel,
|
||||
DebugLevel,
|
||||
}
|
||||
|
||||
// These are the different logging levels. You can set the logging level to log
|
||||
// on your instance of logger, obtained with `logrus.New()`.
|
||||
const (
|
||||
// PanicLevel level, highest level of severity. Logs and then calls panic with the
|
||||
// message passed to Debug, Info, ...
|
||||
PanicLevel Level = iota
|
||||
// FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the
|
||||
// logging level is set to Panic.
|
||||
FatalLevel
|
||||
// ErrorLevel level. Logs. Used for errors that should definitely be noted.
|
||||
// Commonly used for hooks to send errors to an error tracking service.
|
||||
ErrorLevel
|
||||
// WarnLevel level. Non-critical entries that deserve eyes.
|
||||
WarnLevel
|
||||
// InfoLevel level. General operational entries about what's going on inside the
|
||||
// application.
|
||||
InfoLevel
|
||||
// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
|
||||
DebugLevel
|
||||
)
|
||||
|
||||
// Won't compile if StdLogger can't be realized by a log.Logger
|
||||
var (
|
||||
_ StdLogger = &log.Logger{}
|
||||
_ StdLogger = &Entry{}
|
||||
_ StdLogger = &Logger{}
|
||||
)
|
||||
|
||||
// StdLogger is what your logrus-enabled library should take, that way
|
||||
// it'll accept a stdlib logger and a logrus logger. There's no standard
|
||||
// interface, this is the closest we get, unfortunately.
|
||||
type StdLogger interface {
|
||||
Print(...interface{})
|
||||
Printf(string, ...interface{})
|
||||
Println(...interface{})
|
||||
|
||||
Fatal(...interface{})
|
||||
Fatalf(string, ...interface{})
|
||||
Fatalln(...interface{})
|
||||
|
||||
Panic(...interface{})
|
||||
Panicf(string, ...interface{})
|
||||
Panicln(...interface{})
|
||||
}
|
||||
|
||||
// The FieldLogger interface generalizes the Entry and Logger types
|
||||
type FieldLogger interface {
|
||||
WithField(key string, value interface{}) *Entry
|
||||
WithFields(fields Fields) *Entry
|
||||
WithError(err error) *Entry
|
||||
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Printf(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Warningf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
Panicf(format string, args ...interface{})
|
||||
|
||||
Debug(args ...interface{})
|
||||
Info(args ...interface{})
|
||||
Print(args ...interface{})
|
||||
Warn(args ...interface{})
|
||||
Warning(args ...interface{})
|
||||
Error(args ...interface{})
|
||||
Fatal(args ...interface{})
|
||||
Panic(args ...interface{})
|
||||
|
||||
Debugln(args ...interface{})
|
||||
Infoln(args ...interface{})
|
||||
Println(args ...interface{})
|
||||
Warnln(args ...interface{})
|
||||
Warningln(args ...interface{})
|
||||
Errorln(args ...interface{})
|
||||
Fatalln(args ...interface{})
|
||||
Panicln(args ...interface{})
|
||||
}
|
|
@ -0,0 +1,361 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func LogAndAssertJSON(t *testing.T, log func(*Logger), assertions func(fields Fields)) {
|
||||
var buffer bytes.Buffer
|
||||
var fields Fields
|
||||
|
||||
logger := New()
|
||||
logger.Out = &buffer
|
||||
logger.Formatter = new(JSONFormatter)
|
||||
|
||||
log(logger)
|
||||
|
||||
err := json.Unmarshal(buffer.Bytes(), &fields)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assertions(fields)
|
||||
}
|
||||
|
||||
func LogAndAssertText(t *testing.T, log func(*Logger), assertions func(fields map[string]string)) {
|
||||
var buffer bytes.Buffer
|
||||
|
||||
logger := New()
|
||||
logger.Out = &buffer
|
||||
logger.Formatter = &TextFormatter{
|
||||
DisableColors: true,
|
||||
}
|
||||
|
||||
log(logger)
|
||||
|
||||
fields := make(map[string]string)
|
||||
for _, kv := range strings.Split(buffer.String(), " ") {
|
||||
if !strings.Contains(kv, "=") {
|
||||
continue
|
||||
}
|
||||
kvArr := strings.Split(kv, "=")
|
||||
key := strings.TrimSpace(kvArr[0])
|
||||
val := kvArr[1]
|
||||
if kvArr[1][0] == '"' {
|
||||
var err error
|
||||
val, err = strconv.Unquote(val)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
fields[key] = val
|
||||
}
|
||||
assertions(fields)
|
||||
}
|
||||
|
||||
func TestPrint(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Print("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test")
|
||||
assert.Equal(t, fields["level"], "info")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Info("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test")
|
||||
assert.Equal(t, fields["level"], "info")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWarn(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Warn("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test")
|
||||
assert.Equal(t, fields["level"], "warning")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfolnShouldAddSpacesBetweenStrings(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Infoln("test", "test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfolnShouldAddSpacesBetweenStringAndNonstring(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Infoln("test", 10)
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test 10")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfolnShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Infoln(10, 10)
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "10 10")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfoShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Infoln(10, 10)
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "10 10")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfoShouldNotAddSpacesBetweenStringAndNonstring(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Info("test", 10)
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test10")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfoShouldNotAddSpacesBetweenStrings(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.Info("test", "test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "testtest")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithFieldsShouldAllowAssignments(t *testing.T) {
|
||||
var buffer bytes.Buffer
|
||||
var fields Fields
|
||||
|
||||
logger := New()
|
||||
logger.Out = &buffer
|
||||
logger.Formatter = new(JSONFormatter)
|
||||
|
||||
localLog := logger.WithFields(Fields{
|
||||
"key1": "value1",
|
||||
})
|
||||
|
||||
localLog.WithField("key2", "value2").Info("test")
|
||||
err := json.Unmarshal(buffer.Bytes(), &fields)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "value2", fields["key2"])
|
||||
assert.Equal(t, "value1", fields["key1"])
|
||||
|
||||
buffer = bytes.Buffer{}
|
||||
fields = Fields{}
|
||||
localLog.Info("test")
|
||||
err = json.Unmarshal(buffer.Bytes(), &fields)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, ok := fields["key2"]
|
||||
assert.Equal(t, false, ok)
|
||||
assert.Equal(t, "value1", fields["key1"])
|
||||
}
|
||||
|
||||
func TestUserSuppliedFieldDoesNotOverwriteDefaults(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.WithField("msg", "hello").Info("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserSuppliedMsgFieldHasPrefix(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.WithField("msg", "hello").Info("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["msg"], "test")
|
||||
assert.Equal(t, fields["fields.msg"], "hello")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserSuppliedTimeFieldHasPrefix(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.WithField("time", "hello").Info("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["fields.time"], "hello")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserSuppliedLevelFieldHasPrefix(t *testing.T) {
|
||||
LogAndAssertJSON(t, func(log *Logger) {
|
||||
log.WithField("level", 1).Info("test")
|
||||
}, func(fields Fields) {
|
||||
assert.Equal(t, fields["level"], "info")
|
||||
assert.Equal(t, fields["fields.level"], 1.0) // JSON has floats only
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultFieldsAreNotPrefixed(t *testing.T) {
|
||||
LogAndAssertText(t, func(log *Logger) {
|
||||
ll := log.WithField("herp", "derp")
|
||||
ll.Info("hello")
|
||||
ll.Info("bye")
|
||||
}, func(fields map[string]string) {
|
||||
for _, fieldName := range []string{"fields.level", "fields.time", "fields.msg"} {
|
||||
if _, ok := fields[fieldName]; ok {
|
||||
t.Fatalf("should not have prefixed %q: %v", fieldName, fields)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDoubleLoggingDoesntPrefixPreviousFields(t *testing.T) {
|
||||
|
||||
var buffer bytes.Buffer
|
||||
var fields Fields
|
||||
|
||||
logger := New()
|
||||
logger.Out = &buffer
|
||||
logger.Formatter = new(JSONFormatter)
|
||||
|
||||
llog := logger.WithField("context", "eating raw fish")
|
||||
|
||||
llog.Info("looks delicious")
|
||||
|
||||
err := json.Unmarshal(buffer.Bytes(), &fields)
|
||||
assert.NoError(t, err, "should have decoded first message")
|
||||
assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
|
||||
assert.Equal(t, fields["msg"], "looks delicious")
|
||||
assert.Equal(t, fields["context"], "eating raw fish")
|
||||
|
||||
buffer.Reset()
|
||||
|
||||
llog.Warn("omg it is!")
|
||||
|
||||
err = json.Unmarshal(buffer.Bytes(), &fields)
|
||||
assert.NoError(t, err, "should have decoded second message")
|
||||
assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
|
||||
assert.Equal(t, fields["msg"], "omg it is!")
|
||||
assert.Equal(t, fields["context"], "eating raw fish")
|
||||
assert.Nil(t, fields["fields.msg"], "should not have prefixed previous `msg` entry")
|
||||
|
||||
}
|
||||
|
||||
func TestConvertLevelToString(t *testing.T) {
|
||||
assert.Equal(t, "debug", DebugLevel.String())
|
||||
assert.Equal(t, "info", InfoLevel.String())
|
||||
assert.Equal(t, "warning", WarnLevel.String())
|
||||
assert.Equal(t, "error", ErrorLevel.String())
|
||||
assert.Equal(t, "fatal", FatalLevel.String())
|
||||
assert.Equal(t, "panic", PanicLevel.String())
|
||||
}
|
||||
|
||||
func TestParseLevel(t *testing.T) {
|
||||
l, err := ParseLevel("panic")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, PanicLevel, l)
|
||||
|
||||
l, err = ParseLevel("PANIC")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, PanicLevel, l)
|
||||
|
||||
l, err = ParseLevel("fatal")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, FatalLevel, l)
|
||||
|
||||
l, err = ParseLevel("FATAL")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, FatalLevel, l)
|
||||
|
||||
l, err = ParseLevel("error")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, ErrorLevel, l)
|
||||
|
||||
l, err = ParseLevel("ERROR")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, ErrorLevel, l)
|
||||
|
||||
l, err = ParseLevel("warn")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, WarnLevel, l)
|
||||
|
||||
l, err = ParseLevel("WARN")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, WarnLevel, l)
|
||||
|
||||
l, err = ParseLevel("warning")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, WarnLevel, l)
|
||||
|
||||
l, err = ParseLevel("WARNING")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, WarnLevel, l)
|
||||
|
||||
l, err = ParseLevel("info")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, InfoLevel, l)
|
||||
|
||||
l, err = ParseLevel("INFO")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, InfoLevel, l)
|
||||
|
||||
l, err = ParseLevel("debug")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, DebugLevel, l)
|
||||
|
||||
l, err = ParseLevel("DEBUG")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, DebugLevel, l)
|
||||
|
||||
l, err = ParseLevel("invalid")
|
||||
assert.Equal(t, "not a valid logrus Level: \"invalid\"", err.Error())
|
||||
}
|
||||
|
||||
func TestGetSetLevelRace(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
if i%2 == 0 {
|
||||
SetLevel(InfoLevel)
|
||||
} else {
|
||||
GetLevel()
|
||||
}
|
||||
}(i)
|
||||
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestLoggingRace(t *testing.T) {
|
||||
logger := New()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
logger.Info("info")
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Compile test
|
||||
func TestLogrusInterface(t *testing.T) {
|
||||
var buffer bytes.Buffer
|
||||
fn := func(l FieldLogger) {
|
||||
b := l.WithField("key", "value")
|
||||
b.Debug("Test")
|
||||
}
|
||||
// test logger
|
||||
logger := New()
|
||||
logger.Out = &buffer
|
||||
fn(logger)
|
||||
|
||||
// test Entry
|
||||
e := logger.WithField("another", "value")
|
||||
fn(e)
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
// +build appengine
|
||||
|
||||
package logrus
|
||||
|
||||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
// +build darwin freebsd openbsd netbsd dragonfly
|
||||
// +build !appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import "syscall"
|
||||
|
||||
const ioctlReadTermios = syscall.TIOCGETA
|
||||
|
||||
type Termios syscall.Termios
|
|
@ -0,0 +1,14 @@
|
|||
// Based on ssh/terminal:
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import "syscall"
|
||||
|
||||
const ioctlReadTermios = syscall.TCGETS
|
||||
|
||||
type Termios syscall.Termios
|
|
@ -0,0 +1,22 @@
|
|||
// Based on ssh/terminal:
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build linux darwin freebsd openbsd netbsd dragonfly
|
||||
// +build !appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
fd := syscall.Stderr
|
||||
var termios Termios
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
|
||||
return err == 0
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// +build solaris,!appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
_, err := unix.IoctlGetTermios(int(os.Stdout.Fd()), unix.TCGETA)
|
||||
return err == nil
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
// Based on ssh/terminal:
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows,!appengine
|
||||
|
||||
package logrus
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
var (
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
)
|
||||
|
||||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
||||
func IsTerminal() bool {
|
||||
fd := syscall.Stderr
|
||||
var st uint32
|
||||
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
||||
return r != 0 && e == 0
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
nocolor = 0
|
||||
red = 31
|
||||
green = 32
|
||||
yellow = 33
|
||||
blue = 34
|
||||
gray = 37
|
||||
)
|
||||
|
||||
var (
|
||||
baseTimestamp time.Time
|
||||
isTerminal bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
baseTimestamp = time.Now()
|
||||
isTerminal = IsTerminal()
|
||||
}
|
||||
|
||||
func miniTS() int {
|
||||
return int(time.Since(baseTimestamp) / time.Second)
|
||||
}
|
||||
|
||||
type TextFormatter struct {
|
||||
// Set to true to bypass checking for a TTY before outputting colors.
|
||||
ForceColors bool
|
||||
|
||||
// Force disabling colors.
|
||||
DisableColors bool
|
||||
|
||||
// Disable timestamp logging. useful when output is redirected to logging
|
||||
// system that already adds timestamps.
|
||||
DisableTimestamp bool
|
||||
|
||||
// Enable logging the full timestamp when a TTY is attached instead of just
|
||||
// the time passed since beginning of execution.
|
||||
FullTimestamp bool
|
||||
|
||||
// TimestampFormat to use for display when a full timestamp is printed
|
||||
TimestampFormat string
|
||||
|
||||
// The fields are sorted by default for a consistent output. For applications
|
||||
// that log extremely frequently and don't use the JSON formatter this may not
|
||||
// be desired.
|
||||
DisableSorting bool
|
||||
}
|
||||
|
||||
func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
|
||||
var b *bytes.Buffer
|
||||
var keys []string = make([]string, 0, len(entry.Data))
|
||||
for k := range entry.Data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
if !f.DisableSorting {
|
||||
sort.Strings(keys)
|
||||
}
|
||||
if entry.Buffer != nil {
|
||||
b = entry.Buffer
|
||||
} else {
|
||||
b = &bytes.Buffer{}
|
||||
}
|
||||
|
||||
prefixFieldClashes(entry.Data)
|
||||
|
||||
isColorTerminal := isTerminal && (runtime.GOOS != "windows")
|
||||
isColored := (f.ForceColors || isColorTerminal) && !f.DisableColors
|
||||
|
||||
timestampFormat := f.TimestampFormat
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = DefaultTimestampFormat
|
||||
}
|
||||
if isColored {
|
||||
f.printColored(b, entry, keys, timestampFormat)
|
||||
} else {
|
||||
if !f.DisableTimestamp {
|
||||
f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat))
|
||||
}
|
||||
f.appendKeyValue(b, "level", entry.Level.String())
|
||||
if entry.Message != "" {
|
||||
f.appendKeyValue(b, "msg", entry.Message)
|
||||
}
|
||||
for _, key := range keys {
|
||||
f.appendKeyValue(b, key, entry.Data[key])
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteByte('\n')
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []string, timestampFormat string) {
|
||||
var levelColor int
|
||||
switch entry.Level {
|
||||
case DebugLevel:
|
||||
levelColor = gray
|
||||
case WarnLevel:
|
||||
levelColor = yellow
|
||||
case ErrorLevel, FatalLevel, PanicLevel:
|
||||
levelColor = red
|
||||
default:
|
||||
levelColor = blue
|
||||
}
|
||||
|
||||
levelText := strings.ToUpper(entry.Level.String())[0:4]
|
||||
|
||||
if !f.FullTimestamp {
|
||||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%04d] %-44s ", levelColor, levelText, miniTS(), entry.Message)
|
||||
} else {
|
||||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %-44s ", levelColor, levelText, entry.Time.Format(timestampFormat), entry.Message)
|
||||
}
|
||||
for _, k := range keys {
|
||||
v := entry.Data[k]
|
||||
fmt.Fprintf(b, " \x1b[%dm%s\x1b[0m=%+v", levelColor, k, v)
|
||||
}
|
||||
}
|
||||
|
||||
func needsQuoting(text string) bool {
|
||||
for _, ch := range text {
|
||||
if !((ch >= 'a' && ch <= 'z') ||
|
||||
(ch >= 'A' && ch <= 'Z') ||
|
||||
(ch >= '0' && ch <= '9') ||
|
||||
ch == '-' || ch == '.') {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *TextFormatter) appendKeyValue(b *bytes.Buffer, key string, value interface{}) {
|
||||
|
||||
b.WriteString(key)
|
||||
b.WriteByte('=')
|
||||
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
if !needsQuoting(value) {
|
||||
b.WriteString(value)
|
||||
} else {
|
||||
fmt.Fprintf(b, "%q", value)
|
||||
}
|
||||
case error:
|
||||
errmsg := value.Error()
|
||||
if !needsQuoting(errmsg) {
|
||||
b.WriteString(errmsg)
|
||||
} else {
|
||||
fmt.Fprintf(b, "%q", value)
|
||||
}
|
||||
default:
|
||||
fmt.Fprint(b, value)
|
||||
}
|
||||
|
||||
b.WriteByte(' ')
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestQuoting(t *testing.T) {
|
||||
tf := &TextFormatter{DisableColors: true}
|
||||
|
||||
checkQuoting := func(q bool, value interface{}) {
|
||||
b, _ := tf.Format(WithField("test", value))
|
||||
idx := bytes.Index(b, ([]byte)("test="))
|
||||
cont := bytes.Contains(b[idx+5:], []byte{'"'})
|
||||
if cont != q {
|
||||
if q {
|
||||
t.Errorf("quoting expected for: %#v", value)
|
||||
} else {
|
||||
t.Errorf("quoting not expected for: %#v", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
checkQuoting(false, "abcd")
|
||||
checkQuoting(false, "v1.0")
|
||||
checkQuoting(false, "1234567890")
|
||||
checkQuoting(true, "/foobar")
|
||||
checkQuoting(true, "x y")
|
||||
checkQuoting(true, "x,y")
|
||||
checkQuoting(false, errors.New("invalid"))
|
||||
checkQuoting(true, errors.New("invalid argument"))
|
||||
}
|
||||
|
||||
func TestTimestampFormat(t *testing.T) {
|
||||
checkTimeStr := func(format string) {
|
||||
customFormatter := &TextFormatter{DisableColors: true, TimestampFormat: format}
|
||||
customStr, _ := customFormatter.Format(WithField("test", "test"))
|
||||
timeStart := bytes.Index(customStr, ([]byte)("time="))
|
||||
timeEnd := bytes.Index(customStr, ([]byte)("level="))
|
||||
timeStr := customStr[timeStart+5 : timeEnd-1]
|
||||
if timeStr[0] == '"' && timeStr[len(timeStr)-1] == '"' {
|
||||
timeStr = timeStr[1 : len(timeStr)-1]
|
||||
}
|
||||
if format == "" {
|
||||
format = time.RFC3339
|
||||
}
|
||||
_, e := time.Parse(format, (string)(timeStr))
|
||||
if e != nil {
|
||||
t.Errorf("time string \"%s\" did not match provided time format \"%s\": %s", timeStr, format, e)
|
||||
}
|
||||
}
|
||||
|
||||
checkTimeStr("2006-01-02T15:04:05.000000000Z07:00")
|
||||
checkTimeStr("Mon Jan _2 15:04:05 2006")
|
||||
checkTimeStr("")
|
||||
}
|
||||
|
||||
// TODO add tests for sorting etc., this requires a parser for the text
|
||||
// formatter output.
|
|
@ -0,0 +1,53 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func (logger *Logger) Writer() *io.PipeWriter {
|
||||
return logger.WriterLevel(InfoLevel)
|
||||
}
|
||||
|
||||
func (logger *Logger) WriterLevel(level Level) *io.PipeWriter {
|
||||
reader, writer := io.Pipe()
|
||||
|
||||
var printFunc func(args ...interface{})
|
||||
switch level {
|
||||
case DebugLevel:
|
||||
printFunc = logger.Debug
|
||||
case InfoLevel:
|
||||
printFunc = logger.Info
|
||||
case WarnLevel:
|
||||
printFunc = logger.Warn
|
||||
case ErrorLevel:
|
||||
printFunc = logger.Error
|
||||
case FatalLevel:
|
||||
printFunc = logger.Fatal
|
||||
case PanicLevel:
|
||||
printFunc = logger.Panic
|
||||
default:
|
||||
printFunc = logger.Print
|
||||
}
|
||||
|
||||
go logger.writerScanner(reader, printFunc)
|
||||
runtime.SetFinalizer(writer, writerFinalizer)
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
func (logger *Logger) writerScanner(reader *io.PipeReader, printFunc func(args ...interface{})) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
printFunc(scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.Errorf("Error while reading from Writer: %s", err)
|
||||
}
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func writerFinalizer(writer *io.PipeWriter) {
|
||||
writer.Close()
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
*.prof
|
||||
*.test
|
||||
*.swp
|
||||
/bin/
|
|
@ -0,0 +1,20 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013 Ben Johnson
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,18 @@
|
|||
BRANCH=`git rev-parse --abbrev-ref HEAD`
|
||||
COMMIT=`git rev-parse --short HEAD`
|
||||
GOLDFLAGS="-X main.branch $(BRANCH) -X main.commit $(COMMIT)"
|
||||
|
||||
default: build
|
||||
|
||||
race:
|
||||
@go test -v -race -test.run="TestSimulate_(100op|1000op)"
|
||||
|
||||
# go get github.com/kisielk/errcheck
|
||||
errcheck:
|
||||
@errcheck -ignorepkg=bytes -ignore=os:Remove github.com/boltdb/bolt
|
||||
|
||||
test:
|
||||
@go test -v -cover .
|
||||
@go test -v ./cmd/bolt
|
||||
|
||||
.PHONY: fmt test
|
|
@ -0,0 +1,852 @@
|
|||
Bolt [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.svg?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.svg)](https://godoc.org/github.com/boltdb/bolt) ![Version](https://img.shields.io/badge/version-1.2.1-green.svg)
|
||||
====
|
||||
|
||||
Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas]
|
||||
[LMDB project][lmdb]. The goal of the project is to provide a simple,
|
||||
fast, and reliable database for projects that don't require a full database
|
||||
server such as Postgres or MySQL.
|
||||
|
||||
Since Bolt is meant to be used as such a low-level piece of functionality,
|
||||
simplicity is key. The API will be small and only focus on getting values
|
||||
and setting values. That's it.
|
||||
|
||||
[hyc_symas]: https://twitter.com/hyc_symas
|
||||
[lmdb]: http://symas.com/mdb/
|
||||
|
||||
## Project Status
|
||||
|
||||
Bolt is stable and the API is fixed. Full unit test coverage and randomized
|
||||
black box testing are used to ensure database consistency and thread safety.
|
||||
Bolt is currently in high-load production environments serving databases as
|
||||
large as 1TB. Many companies such as Shopify and Heroku use Bolt-backed
|
||||
services every day.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Getting Started](#getting-started)
|
||||
- [Installing](#installing)
|
||||
- [Opening a database](#opening-a-database)
|
||||
- [Transactions](#transactions)
|
||||
- [Read-write transactions](#read-write-transactions)
|
||||
- [Read-only transactions](#read-only-transactions)
|
||||
- [Batch read-write transactions](#batch-read-write-transactions)
|
||||
- [Managing transactions manually](#managing-transactions-manually)
|
||||
- [Using buckets](#using-buckets)
|
||||
- [Using key/value pairs](#using-keyvalue-pairs)
|
||||
- [Autoincrementing integer for the bucket](#autoincrementing-integer-for-the-bucket)
|
||||
- [Iterating over keys](#iterating-over-keys)
|
||||
- [Prefix scans](#prefix-scans)
|
||||
- [Range scans](#range-scans)
|
||||
- [ForEach()](#foreach)
|
||||
- [Nested buckets](#nested-buckets)
|
||||
- [Database backups](#database-backups)
|
||||
- [Statistics](#statistics)
|
||||
- [Read-Only Mode](#read-only-mode)
|
||||
- [Mobile Use (iOS/Android)](#mobile-use-iosandroid)
|
||||
- [Resources](#resources)
|
||||
- [Comparison with other databases](#comparison-with-other-databases)
|
||||
- [Postgres, MySQL, & other relational databases](#postgres-mysql--other-relational-databases)
|
||||
- [LevelDB, RocksDB](#leveldb-rocksdb)
|
||||
- [LMDB](#lmdb)
|
||||
- [Caveats & Limitations](#caveats--limitations)
|
||||
- [Reading the Source](#reading-the-source)
|
||||
- [Other Projects Using Bolt](#other-projects-using-bolt)
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Installing
|
||||
|
||||
To start using Bolt, install Go and run `go get`:
|
||||
|
||||
```sh
|
||||
$ go get github.com/boltdb/bolt/...
|
||||
```
|
||||
|
||||
This will retrieve the library and install the `bolt` command line utility into
|
||||
your `$GOBIN` path.
|
||||
|
||||
|
||||
### Opening a database
|
||||
|
||||
The top-level object in Bolt is a `DB`. It is represented as a single file on
|
||||
your disk and represents a consistent snapshot of your data.
|
||||
|
||||
To open your database, simply use the `bolt.Open()` function:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Open the my.db data file in your current directory.
|
||||
// It will be created if it doesn't exist.
|
||||
db, err := bolt.Open("my.db", 0600, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Please note that Bolt obtains a file lock on the data file so multiple processes
|
||||
cannot open the same database at the same time. Opening an already open Bolt
|
||||
database will cause it to hang until the other process closes it. To prevent
|
||||
an indefinite wait you can pass a timeout option to the `Open()` function:
|
||||
|
||||
```go
|
||||
db, err := bolt.Open("my.db", 0600, &bolt.Options{Timeout: 1 * time.Second})
|
||||
```
|
||||
|
||||
|
||||
### Transactions
|
||||
|
||||
Bolt allows only one read-write transaction at a time but allows as many
|
||||
read-only transactions as you want at a time. Each transaction has a consistent
|
||||
view of the data as it existed when the transaction started.
|
||||
|
||||
Individual transactions and all objects created from them (e.g. buckets, keys)
|
||||
are not thread safe. To work with data in multiple goroutines you must start
|
||||
a transaction for each one or use locking to ensure only one goroutine accesses
|
||||
a transaction at a time. Creating transaction from the `DB` is thread safe.
|
||||
|
||||
Read-only transactions and read-write transactions should not depend on one
|
||||
another and generally shouldn't be opened simultaneously in the same goroutine.
|
||||
This can cause a deadlock as the read-write transaction needs to periodically
|
||||
re-map the data file but it cannot do so while a read-only transaction is open.
|
||||
|
||||
|
||||
#### Read-write transactions
|
||||
|
||||
To start a read-write transaction, you can use the `DB.Update()` function:
|
||||
|
||||
```go
|
||||
err := db.Update(func(tx *bolt.Tx) error {
|
||||
...
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
Inside the closure, you have a consistent view of the database. You commit the
|
||||
transaction by returning `nil` at the end. You can also rollback the transaction
|
||||
at any point by returning an error. All database operations are allowed inside
|
||||
a read-write transaction.
|
||||
|
||||
Always check the return error as it will report any disk failures that can cause
|
||||
your transaction to not complete. If you return an error within your closure
|
||||
it will be passed through.
|
||||
|
||||
|
||||
#### Read-only transactions
|
||||
|
||||
To start a read-only transaction, you can use the `DB.View()` function:
|
||||
|
||||
```go
|
||||
err := db.View(func(tx *bolt.Tx) error {
|
||||
...
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
You also get a consistent view of the database within this closure, however,
|
||||
no mutating operations are allowed within a read-only transaction. You can only
|
||||
retrieve buckets, retrieve values, and copy the database within a read-only
|
||||
transaction.
|
||||
|
||||
|
||||
#### Batch read-write transactions
|
||||
|
||||
Each `DB.Update()` waits for disk to commit the writes. This overhead
|
||||
can be minimized by combining multiple updates with the `DB.Batch()`
|
||||
function:
|
||||
|
||||
```go
|
||||
err := db.Batch(func(tx *bolt.Tx) error {
|
||||
...
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
Concurrent Batch calls are opportunistically combined into larger
|
||||
transactions. Batch is only useful when there are multiple goroutines
|
||||
calling it.
|
||||
|
||||
The trade-off is that `Batch` can call the given
|
||||
function multiple times, if parts of the transaction fail. The
|
||||
function must be idempotent and side effects must take effect only
|
||||
after a successful return from `DB.Batch()`.
|
||||
|
||||
For example: don't display messages from inside the function, instead
|
||||
set variables in the enclosing scope:
|
||||
|
||||
```go
|
||||
var id uint64
|
||||
err := db.Batch(func(tx *bolt.Tx) error {
|
||||
// Find last key in bucket, decode as bigendian uint64, increment
|
||||
// by one, encode back to []byte, and add new key.
|
||||
...
|
||||
id = newValue
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return ...
|
||||
}
|
||||
fmt.Println("Allocated ID %d", id)
|
||||
```
|
||||
|
||||
|
||||
#### Managing transactions manually
|
||||
|
||||
The `DB.View()` and `DB.Update()` functions are wrappers around the `DB.Begin()`
|
||||
function. These helper functions will start the transaction, execute a function,
|
||||
and then safely close your transaction if an error is returned. This is the
|
||||
recommended way to use Bolt transactions.
|
||||
|
||||
However, sometimes you may want to manually start and end your transactions.
|
||||
You can use the `Tx.Begin()` function directly but **please** be sure to close
|
||||
the transaction.
|
||||
|
||||
```go
|
||||
// Start a writable transaction.
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Use the transaction...
|
||||
_, err := tx.CreateBucket([]byte("MyBucket"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit the transaction and check for error.
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
The first argument to `DB.Begin()` is a boolean stating if the transaction
|
||||
should be writable.
|
||||
|
||||
|
||||
### Using buckets
|
||||
|
||||
Buckets are collections of key/value pairs within the database. All keys in a
|
||||
bucket must be unique. You can create a bucket using the `DB.CreateBucket()`
|
||||
function:
|
||||
|
||||
```go
|
||||
db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("MyBucket"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create bucket: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
You can also create a bucket only if it doesn't exist by using the
|
||||
`Tx.CreateBucketIfNotExists()` function. It's a common pattern to call this
|
||||
function for all your top-level buckets after you open your database so you can
|
||||
guarantee that they exist for future transactions.
|
||||
|
||||
To delete a bucket, simply call the `Tx.DeleteBucket()` function.
|
||||
|
||||
|
||||
### Using key/value pairs
|
||||
|
||||
To save a key/value pair to a bucket, use the `Bucket.Put()` function:
|
||||
|
||||
```go
|
||||
db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("MyBucket"))
|
||||
err := b.Put([]byte("answer"), []byte("42"))
|
||||
return err
|
||||
})
|
||||
```
|
||||
|
||||
This will set the value of the `"answer"` key to `"42"` in the `MyBucket`
|
||||
bucket. To retrieve this value, we can use the `Bucket.Get()` function:
|
||||
|
||||
```go
|
||||
db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("MyBucket"))
|
||||
v := b.Get([]byte("answer"))
|
||||
fmt.Printf("The answer is: %s\n", v)
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
The `Get()` function does not return an error because its operation is
|
||||
guaranteed to work (unless there is some kind of system failure). If the key
|
||||
exists then it will return its byte slice value. If it doesn't exist then it
|
||||
will return `nil`. It's important to note that you can have a zero-length value
|
||||
set to a key which is different than the key not existing.
|
||||
|
||||
Use the `Bucket.Delete()` function to delete a key from the bucket.
|
||||
|
||||
Please note that values returned from `Get()` are only valid while the
|
||||
transaction is open. If you need to use a value outside of the transaction
|
||||
then you must use `copy()` to copy it to another byte slice.
|
||||
|
||||
|
||||
### Autoincrementing integer for the bucket
|
||||
By using the `NextSequence()` function, you can let Bolt determine a sequence
|
||||
which can be used as the unique identifier for your key/value pairs. See the
|
||||
example below.
|
||||
|
||||
```go
|
||||
// CreateUser saves u to the store. The new user ID is set on u once the data is persisted.
|
||||
func (s *Store) CreateUser(u *User) error {
|
||||
return s.db.Update(func(tx *bolt.Tx) error {
|
||||
// Retrieve the users bucket.
|
||||
// This should be created when the DB is first opened.
|
||||
b := tx.Bucket([]byte("users"))
|
||||
|
||||
// Generate ID for the user.
|
||||
// This returns an error only if the Tx is closed or not writeable.
|
||||
// That can't happen in an Update() call so I ignore the error check.
|
||||
id, _ := b.NextSequence()
|
||||
u.ID = int(id)
|
||||
|
||||
// Marshal user data into bytes.
|
||||
buf, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist bytes to users bucket.
|
||||
return b.Put(itob(u.ID), buf)
|
||||
})
|
||||
}
|
||||
|
||||
// itob returns an 8-byte big endian representation of v.
|
||||
func itob(v int) []byte {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(v))
|
||||
return b
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
### Iterating over keys
|
||||
|
||||
Bolt stores its keys in byte-sorted order within a bucket. This makes sequential
|
||||
iteration over these keys extremely fast. To iterate over keys we'll use a
|
||||
`Cursor`:
|
||||
|
||||
```go
|
||||
db.View(func(tx *bolt.Tx) error {
|
||||
// Assume bucket exists and has keys
|
||||
b := tx.Bucket([]byte("MyBucket"))
|
||||
|
||||
c := b.Cursor()
|
||||
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
fmt.Printf("key=%s, value=%s\n", k, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
The cursor allows you to move to a specific point in the list of keys and move
|
||||
forward or backward through the keys one at a time.
|
||||
|
||||
The following functions are available on the cursor:
|
||||
|
||||
```
|
||||
First() Move to the first key.
|
||||
Last() Move to the last key.
|
||||
Seek() Move to a specific key.
|
||||
Next() Move to the next key.
|
||||
Prev() Move to the previous key.
|
||||
```
|
||||
|
||||
Each of those functions has a return signature of `(key []byte, value []byte)`.
|
||||
When you have iterated to the end of the cursor then `Next()` will return a
|
||||
`nil` key. You must seek to a position using `First()`, `Last()`, or `Seek()`
|
||||
before calling `Next()` or `Prev()`. If you do not seek to a position then
|
||||
these functions will return a `nil` key.
|
||||
|
||||
During iteration, if the key is non-`nil` but the value is `nil`, that means
|
||||
the key refers to a bucket rather than a value. Use `Bucket.Bucket()` to
|
||||
access the sub-bucket.
|
||||
|
||||
|
||||
#### Prefix scans
|
||||
|
||||
To iterate over a key prefix, you can combine `Seek()` and `bytes.HasPrefix()`:
|
||||
|
||||
```go
|
||||
db.View(func(tx *bolt.Tx) error {
|
||||
// Assume bucket exists and has keys
|
||||
c := tx.Bucket([]byte("MyBucket")).Cursor()
|
||||
|
||||
prefix := []byte("1234")
|
||||
for k, v := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, v = c.Next() {
|
||||
fmt.Printf("key=%s, value=%s\n", k, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
#### Range scans
|
||||
|
||||
Another common use case is scanning over a range such as a time range. If you
|
||||
use a sortable time encoding such as RFC3339 then you can query a specific
|
||||
date range like this:
|
||||
|
||||
```go
|
||||
db.View(func(tx *bolt.Tx) error {
|
||||
// Assume our events bucket exists and has RFC3339 encoded time keys.
|
||||
c := tx.Bucket([]byte("Events")).Cursor()
|
||||
|
||||
// Our time range spans the 90's decade.
|
||||
min := []byte("1990-01-01T00:00:00Z")
|
||||
max := []byte("2000-01-01T00:00:00Z")
|
||||
|
||||
// Iterate over the 90's.
|
||||
for k, v := c.Seek(min); k != nil && bytes.Compare(k, max) <= 0; k, v = c.Next() {
|
||||
fmt.Printf("%s: %s\n", k, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
Note that, while RFC3339 is sortable, the Golang implementation of RFC3339Nano does not use a fixed number of digits after the decimal point and is therefore not sortable.
|
||||
|
||||
|
||||
#### ForEach()
|
||||
|
||||
You can also use the function `ForEach()` if you know you'll be iterating over
|
||||
all the keys in a bucket:
|
||||
|
||||
```go
|
||||
db.View(func(tx *bolt.Tx) error {
|
||||
// Assume bucket exists and has keys
|
||||
b := tx.Bucket([]byte("MyBucket"))
|
||||
|
||||
b.ForEach(func(k, v []byte) error {
|
||||
fmt.Printf("key=%s, value=%s\n", k, v)
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
|
||||
### Nested buckets
|
||||
|
||||
You can also store a bucket in a key to create nested buckets. The API is the
|
||||
same as the bucket management API on the `DB` object:
|
||||
|
||||
```go
|
||||
func (*Bucket) CreateBucket(key []byte) (*Bucket, error)
|
||||
func (*Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error)
|
||||
func (*Bucket) DeleteBucket(key []byte) error
|
||||
```
|
||||
|
||||
|
||||
### Database backups
|
||||
|
||||
Bolt is a single file so it's easy to backup. You can use the `Tx.WriteTo()`
|
||||
function to write a consistent view of the database to a writer. If you call
|
||||
this from a read-only transaction, it will perform a hot backup and not block
|
||||
your other database reads and writes.
|
||||
|
||||
By default, it will use a regular file handle which will utilize the operating
|
||||
system's page cache. See the [`Tx`](https://godoc.org/github.com/boltdb/bolt#Tx)
|
||||
documentation for information about optimizing for larger-than-RAM datasets.
|
||||
|
||||
One common use case is to backup over HTTP so you can use tools like `cURL` to
|
||||
do database backups:
|
||||
|
||||
```go
|
||||
func BackupHandleFunc(w http.ResponseWriter, req *http.Request) {
|
||||
err := db.View(func(tx *bolt.Tx) error {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="my.db"`)
|
||||
w.Header().Set("Content-Length", strconv.Itoa(int(tx.Size())))
|
||||
_, err := tx.WriteTo(w)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Then you can backup using this command:
|
||||
|
||||
```sh
|
||||
$ curl http://localhost/backup > my.db
|
||||
```
|
||||
|
||||
Or you can open your browser to `http://localhost/backup` and it will download
|
||||
automatically.
|
||||
|
||||
If you want to backup to another file you can use the `Tx.CopyFile()` helper
|
||||
function.
|
||||
|
||||
|
||||
### Statistics
|
||||
|
||||
The database keeps a running count of many of the internal operations it
|
||||
performs so you can better understand what's going on. By grabbing a snapshot
|
||||
of these stats at two points in time we can see what operations were performed
|
||||
in that time range.
|
||||
|
||||
For example, we could start a goroutine to log stats every 10 seconds:
|
||||
|
||||
```go
|
||||
go func() {
|
||||
// Grab the initial stats.
|
||||
prev := db.Stats()
|
||||
|
||||
for {
|
||||
// Wait for 10s.
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Grab the current stats and diff them.
|
||||
stats := db.Stats()
|
||||
diff := stats.Sub(&prev)
|
||||
|
||||
// Encode stats to JSON and print to STDERR.
|
||||
json.NewEncoder(os.Stderr).Encode(diff)
|
||||
|
||||
// Save stats for the next loop.
|
||||
prev = stats
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
It's also useful to pipe these stats to a service such as statsd for monitoring
|
||||
or to provide an HTTP endpoint that will perform a fixed-length sample.
|
||||
|
||||
|
||||
### Read-Only Mode
|
||||
|
||||
Sometimes it is useful to create a shared, read-only Bolt database. To this,
|
||||
set the `Options.ReadOnly` flag when opening your database. Read-only mode
|
||||
uses a shared lock to allow multiple processes to read from the database but
|
||||
it will block any processes from opening the database in read-write mode.
|
||||
|
||||
```go
|
||||
db, err := bolt.Open("my.db", 0666, &bolt.Options{ReadOnly: true})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
### Mobile Use (iOS/Android)
|
||||
|
||||
Bolt is able to run on mobile devices by leveraging the binding feature of the
|
||||
[gomobile](https://github.com/golang/mobile) tool. Create a struct that will
|
||||
contain your database logic and a reference to a `*bolt.DB` with a initializing
|
||||
constructor that takes in a filepath where the database file will be stored.
|
||||
Neither Android nor iOS require extra permissions or cleanup from using this method.
|
||||
|
||||
```go
|
||||
func NewBoltDB(filepath string) *BoltDB {
|
||||
db, err := bolt.Open(filepath+"/demo.db", 0600, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return &BoltDB{db}
|
||||
}
|
||||
|
||||
type BoltDB struct {
|
||||
db *bolt.DB
|
||||
...
|
||||
}
|
||||
|
||||
func (b *BoltDB) Path() string {
|
||||
return b.db.Path()
|
||||
}
|
||||
|
||||
func (b *BoltDB) Close() {
|
||||
b.db.Close()
|
||||
}
|
||||
```
|
||||
|
||||
Database logic should be defined as methods on this wrapper struct.
|
||||
|
||||
To initialize this struct from the native language (both platforms now sync
|
||||
their local storage to the cloud. These snippets disable that functionality for the
|
||||
database file):
|
||||
|
||||
#### Android
|
||||
|
||||
```java
|
||||
String path;
|
||||
if (android.os.Build.VERSION.SDK_INT >=android.os.Build.VERSION_CODES.LOLLIPOP){
|
||||
path = getNoBackupFilesDir().getAbsolutePath();
|
||||
} else{
|
||||
path = getFilesDir().getAbsolutePath();
|
||||
}
|
||||
Boltmobiledemo.BoltDB boltDB = Boltmobiledemo.NewBoltDB(path)
|
||||
```
|
||||
|
||||
#### iOS
|
||||
|
||||
```objc
|
||||
- (void)demo {
|
||||
NSString* path = [NSSearchPathForDirectoriesInDomains(NSLibraryDirectory,
|
||||
NSUserDomainMask,
|
||||
YES) objectAtIndex:0];
|
||||
GoBoltmobiledemoBoltDB * demo = GoBoltmobiledemoNewBoltDB(path);
|
||||
[self addSkipBackupAttributeToItemAtPath:demo.path];
|
||||
//Some DB Logic would go here
|
||||
[demo close];
|
||||
}
|
||||
|
||||
- (BOOL)addSkipBackupAttributeToItemAtPath:(NSString *) filePathString
|
||||
{
|
||||
NSURL* URL= [NSURL fileURLWithPath: filePathString];
|
||||
assert([[NSFileManager defaultManager] fileExistsAtPath: [URL path]]);
|
||||
|
||||
NSError *error = nil;
|
||||
BOOL success = [URL setResourceValue: [NSNumber numberWithBool: YES]
|
||||
forKey: NSURLIsExcludedFromBackupKey error: &error];
|
||||
if(!success){
|
||||
NSLog(@"Error excluding %@ from backup %@", [URL lastPathComponent], error);
|
||||
}
|
||||
return success;
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
For more information on getting started with Bolt, check out the following articles:
|
||||
|
||||
* [Intro to BoltDB: Painless Performant Persistence](http://npf.io/2014/07/intro-to-boltdb-painless-performant-persistence/) by [Nate Finch](https://github.com/natefinch).
|
||||
* [Bolt -- an embedded key/value database for Go](https://www.progville.com/go/bolt-embedded-db-golang/) by Progville
|
||||
|
||||
|
||||
## Comparison with other databases
|
||||
|
||||
### Postgres, MySQL, & other relational databases
|
||||
|
||||
Relational databases structure data into rows and are only accessible through
|
||||
the use of SQL. This approach provides flexibility in how you store and query
|
||||
your data but also incurs overhead in parsing and planning SQL statements. Bolt
|
||||
accesses all data by a byte slice key. This makes Bolt fast to read and write
|
||||
data by key but provides no built-in support for joining values together.
|
||||
|
||||
Most relational databases (with the exception of SQLite) are standalone servers
|
||||
that run separately from your application. This gives your systems
|
||||
flexibility to connect multiple application servers to a single database
|
||||
server but also adds overhead in serializing and transporting data over the
|
||||
network. Bolt runs as a library included in your application so all data access
|
||||
has to go through your application's process. This brings data closer to your
|
||||
application but limits multi-process access to the data.
|
||||
|
||||
|
||||
### LevelDB, RocksDB
|
||||
|
||||
LevelDB and its derivatives (RocksDB, HyperLevelDB) are similar to Bolt in that
|
||||
they are libraries bundled into the application, however, their underlying
|
||||
structure is a log-structured merge-tree (LSM tree). An LSM tree optimizes
|
||||
random writes by using a write ahead log and multi-tiered, sorted files called
|
||||
SSTables. Bolt uses a B+tree internally and only a single file. Both approaches
|
||||
have trade-offs.
|
||||
|
||||
If you require a high random write throughput (>10,000 w/sec) or you need to use
|
||||
spinning disks then LevelDB could be a good choice. If your application is
|
||||
read-heavy or does a lot of range scans then Bolt could be a good choice.
|
||||
|
||||
One other important consideration is that LevelDB does not have transactions.
|
||||
It supports batch writing of key/values pairs and it supports read snapshots
|
||||
but it will not give you the ability to do a compare-and-swap operation safely.
|
||||
Bolt supports fully serializable ACID transactions.
|
||||
|
||||
|
||||
### LMDB
|
||||
|
||||
Bolt was originally a port of LMDB so it is architecturally similar. Both use
|
||||
a B+tree, have ACID semantics with fully serializable transactions, and support
|
||||
lock-free MVCC using a single writer and multiple readers.
|
||||
|
||||
The two projects have somewhat diverged. LMDB heavily focuses on raw performance
|
||||
while Bolt has focused on simplicity and ease of use. For example, LMDB allows
|
||||
several unsafe actions such as direct writes for the sake of performance. Bolt
|
||||
opts to disallow actions which can leave the database in a corrupted state. The
|
||||
only exception to this in Bolt is `DB.NoSync`.
|
||||
|
||||
There are also a few differences in API. LMDB requires a maximum mmap size when
|
||||
opening an `mdb_env` whereas Bolt will handle incremental mmap resizing
|
||||
automatically. LMDB overloads the getter and setter functions with multiple
|
||||
flags whereas Bolt splits these specialized cases into their own functions.
|
||||
|
||||
|
||||
## Caveats & Limitations
|
||||
|
||||
It's important to pick the right tool for the job and Bolt is no exception.
|
||||
Here are a few things to note when evaluating and using Bolt:
|
||||
|
||||
* Bolt is good for read intensive workloads. Sequential write performance is
|
||||
also fast but random writes can be slow. You can use `DB.Batch()` or add a
|
||||
write-ahead log to help mitigate this issue.
|
||||
|
||||
* Bolt uses a B+tree internally so there can be a lot of random page access.
|
||||
SSDs provide a significant performance boost over spinning disks.
|
||||
|
||||
* Try to avoid long running read transactions. Bolt uses copy-on-write so
|
||||
old pages cannot be reclaimed while an old transaction is using them.
|
||||
|
||||
* Byte slices returned from Bolt are only valid during a transaction. Once the
|
||||
transaction has been committed or rolled back then the memory they point to
|
||||
can be reused by a new page or can be unmapped from virtual memory and you'll
|
||||
see an `unexpected fault address` panic when accessing it.
|
||||
|
||||
* Be careful when using `Bucket.FillPercent`. Setting a high fill percent for
|
||||
buckets that have random inserts will cause your database to have very poor
|
||||
page utilization.
|
||||
|
||||
* Use larger buckets in general. Smaller buckets causes poor page utilization
|
||||
once they become larger than the page size (typically 4KB).
|
||||
|
||||
* Bulk loading a lot of random writes into a new bucket can be slow as the
|
||||
page will not split until the transaction is committed. Randomly inserting
|
||||
more than 100,000 key/value pairs into a single new bucket in a single
|
||||
transaction is not advised.
|
||||
|
||||
* Bolt uses a memory-mapped file so the underlying operating system handles the
|
||||
caching of the data. Typically, the OS will cache as much of the file as it
|
||||
can in memory and will release memory as needed to other processes. This means
|
||||
that Bolt can show very high memory usage when working with large databases.
|
||||
However, this is expected and the OS will release memory as needed. Bolt can
|
||||
handle databases much larger than the available physical RAM, provided its
|
||||
memory-map fits in the process virtual address space. It may be problematic
|
||||
on 32-bits systems.
|
||||
|
||||
* The data structures in the Bolt database are memory mapped so the data file
|
||||
will be endian specific. This means that you cannot copy a Bolt file from a
|
||||
little endian machine to a big endian machine and have it work. For most
|
||||
users this is not a concern since most modern CPUs are little endian.
|
||||
|
||||
* Because of the way pages are laid out on disk, Bolt cannot truncate data files
|
||||
and return free pages back to the disk. Instead, Bolt maintains a free list
|
||||
of unused pages within its data file. These free pages can be reused by later
|
||||
transactions. This works well for many use cases as databases generally tend
|
||||
to grow. However, it's important to note that deleting large chunks of data
|
||||
will not allow you to reclaim that space on disk.
|
||||
|
||||
For more information on page allocation, [see this comment][page-allocation].
|
||||
|
||||
[page-allocation]: https://github.com/boltdb/bolt/issues/308#issuecomment-74811638
|
||||
|
||||
|
||||
## Reading the Source
|
||||
|
||||
Bolt is a relatively small code base (<3KLOC) for an embedded, serializable,
|
||||
transactional key/value database so it can be a good starting point for people
|
||||
interested in how databases work.
|
||||
|
||||
The best places to start are the main entry points into Bolt:
|
||||
|
||||
- `Open()` - Initializes the reference to the database. It's responsible for
|
||||
creating the database if it doesn't exist, obtaining an exclusive lock on the
|
||||
file, reading the meta pages, & memory-mapping the file.
|
||||
|
||||
- `DB.Begin()` - Starts a read-only or read-write transaction depending on the
|
||||
value of the `writable` argument. This requires briefly obtaining the "meta"
|
||||
lock to keep track of open transactions. Only one read-write transaction can
|
||||
exist at a time so the "rwlock" is acquired during the life of a read-write
|
||||
transaction.
|
||||
|
||||
- `Bucket.Put()` - Writes a key/value pair into a bucket. After validating the
|
||||
arguments, a cursor is used to traverse the B+tree to the page and position
|
||||
where they key & value will be written. Once the position is found, the bucket
|
||||
materializes the underlying page and the page's parent pages into memory as
|
||||
"nodes". These nodes are where mutations occur during read-write transactions.
|
||||
These changes get flushed to disk during commit.
|
||||
|
||||
- `Bucket.Get()` - Retrieves a key/value pair from a bucket. This uses a cursor
|
||||
to move to the page & position of a key/value pair. During a read-only
|
||||
transaction, the key and value data is returned as a direct reference to the
|
||||
underlying mmap file so there's no allocation overhead. For read-write
|
||||
transactions, this data may reference the mmap file or one of the in-memory
|
||||
node values.
|
||||
|
||||
- `Cursor` - This object is simply for traversing the B+tree of on-disk pages
|
||||
or in-memory nodes. It can seek to a specific key, move to the first or last
|
||||
value, or it can move forward or backward. The cursor handles the movement up
|
||||
and down the B+tree transparently to the end user.
|
||||
|
||||
- `Tx.Commit()` - Converts the in-memory dirty nodes and the list of free pages
|
||||
into pages to be written to disk. Writing to disk then occurs in two phases.
|
||||
First, the dirty pages are written to disk and an `fsync()` occurs. Second, a
|
||||
new meta page with an incremented transaction ID is written and another
|
||||
`fsync()` occurs. This two phase write ensures that partially written data
|
||||
pages are ignored in the event of a crash since the meta page pointing to them
|
||||
is never written. Partially written meta pages are invalidated because they
|
||||
are written with a checksum.
|
||||
|
||||
If you have additional notes that could be helpful for others, please submit
|
||||
them via pull request.
|
||||
|
||||
|
||||
## Other Projects Using Bolt
|
||||
|
||||
Below is a list of public, open source projects that use Bolt:
|
||||
|
||||
* [BoltDbWeb](https://github.com/evnix/boltdbweb) - A web based GUI for BoltDB files.
|
||||
* [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard.
|
||||
* [Bazil](https://bazil.org/) - A file system that lets your data reside where it is most convenient for it to reside.
|
||||
* [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb.
|
||||
* [Skybox Analytics](https://github.com/skybox/skybox) - A standalone funnel analysis tool for web analytics.
|
||||
* [Scuttlebutt](https://github.com/benbjohnson/scuttlebutt) - Uses Bolt to store and process all Twitter mentions of GitHub projects.
|
||||
* [Wiki](https://github.com/peterhellberg/wiki) - A tiny wiki using Goji, BoltDB and Blackfriday.
|
||||
* [ChainStore](https://github.com/pressly/chainstore) - Simple key-value interface to a variety of storage engines organized as a chain of operations.
|
||||
* [MetricBase](https://github.com/msiebuhr/MetricBase) - Single-binary version of Graphite.
|
||||
* [Gitchain](https://github.com/gitchain/gitchain) - Decentralized, peer-to-peer Git repositories aka "Git meets Bitcoin".
|
||||
* [event-shuttle](https://github.com/sclasen/event-shuttle) - A Unix system service to collect and reliably deliver messages to Kafka.
|
||||
* [ipxed](https://github.com/kelseyhightower/ipxed) - Web interface and api for ipxed.
|
||||
* [BoltStore](https://github.com/yosssi/boltstore) - Session store using Bolt.
|
||||
* [photosite/session](https://godoc.org/bitbucket.org/kardianos/photosite/session) - Sessions for a photo viewing site.
|
||||
* [LedisDB](https://github.com/siddontang/ledisdb) - A high performance NoSQL, using Bolt as optional storage.
|
||||
* [ipLocator](https://github.com/AndreasBriese/ipLocator) - A fast ip-geo-location-server using bolt with bloom filters.
|
||||
* [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend.
|
||||
* [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend.
|
||||
* [tentacool](https://github.com/optiflows/tentacool) - REST api server to manage system stuff (IP, DNS, Gateway...) on a linux server.
|
||||
* [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read.
|
||||
* [InfluxDB](https://influxdata.com) - Scalable datastore for metrics, events, and real-time analytics.
|
||||
* [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data.
|
||||
* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system.
|
||||
* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware.
|
||||
* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistent, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs.
|
||||
* [drive](https://github.com/odeke-em/drive) - drive is an unofficial Google Drive command line client for \*NIX operating systems.
|
||||
* [stow](https://github.com/djherbis/stow) - a persistence manager for objects
|
||||
backed by boltdb.
|
||||
* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining
|
||||
simple tx and key scans.
|
||||
* [mbuckets](https://github.com/abhigupta912/mbuckets) - A Bolt wrapper that allows easy operations on multi level (nested) buckets.
|
||||
* [Request Baskets](https://github.com/darklynx/request-baskets) - A web service to collect arbitrary HTTP requests and inspect them via REST API or simple web UI, similar to [RequestBin](http://requestb.in/) service
|
||||
* [Go Report Card](https://goreportcard.com/) - Go code quality report cards as a (free and open source) service.
|
||||
* [Boltdb Boilerplate](https://github.com/bobintornado/boltdb-boilerplate) - Boilerplate wrapper around bolt aiming to make simple calls one-liners.
|
||||
* [lru](https://github.com/crowdriff/lru) - Easy to use Bolt-backed Least-Recently-Used (LRU) read-through cache with chainable remote stores.
|
||||
* [Storm](https://github.com/asdine/storm) - A simple ORM around BoltDB.
|
||||
* [GoWebApp](https://github.com/josephspurrier/gowebapp) - A basic MVC web application in Go using BoltDB.
|
||||
* [SimpleBolt](https://github.com/xyproto/simplebolt) - A simple way to use BoltDB. Deals mainly with strings.
|
||||
* [Algernon](https://github.com/xyproto/algernon) - A HTTP/2 web server with built-in support for Lua. Uses BoltDB as the default database backend.
|
||||
* [MuLiFS](https://github.com/dankomiocevic/mulifs) - Music Library Filesystem creates a filesystem to organise your music files.
|
||||
* [GoShort](https://github.com/pankajkhairnar/goShort) - GoShort is a URL shortener written in Golang and BoltDB for persistent key/value storage and for routing it's using high performent HTTPRouter.
|
||||
|
||||
If you are using Bolt in a project please send a pull request to add it to the list.
|
|
@ -0,0 +1,18 @@
|
|||
version: "{build}"
|
||||
|
||||
os: Windows Server 2012 R2
|
||||
|
||||
clone_folder: c:\gopath\src\github.com\boltdb\bolt
|
||||
|
||||
environment:
|
||||
GOPATH: c:\gopath
|
||||
|
||||
install:
|
||||
- echo %PATH%
|
||||
- echo %GOPATH%
|
||||
- go version
|
||||
- go env
|
||||
- go get -v -t ./...
|
||||
|
||||
build_script:
|
||||
- go test -v ./...
|
|
@ -0,0 +1,7 @@
|
|||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0x7FFFFFFF // 2GB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0xFFFFFFF
|
|
@ -0,0 +1,7 @@
|
|||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -0,0 +1,7 @@
|
|||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0x7FFFFFFF // 2GB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0xFFFFFFF
|
|
@ -0,0 +1,9 @@
|
|||
// +build arm64
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -0,0 +1,10 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// fdatasync flushes written data to a file descriptor.
|
||||
func fdatasync(db *DB) error {
|
||||
return syscall.Fdatasync(int(db.file.Fd()))
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
msAsync = 1 << iota // perform asynchronous writes
|
||||
msSync // perform synchronous writes
|
||||
msInvalidate // invalidate cached data
|
||||
)
|
||||
|
||||
func msync(db *DB) error {
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(db.data)), uintptr(db.datasz), msInvalidate)
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fdatasync(db *DB) error {
|
||||
if db.data != nil {
|
||||
return msync(db)
|
||||
}
|
||||
return db.file.Sync()
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
// +build ppc
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0x7FFFFFFF // 2GB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0xFFFFFFF
|
|
@ -0,0 +1,9 @@
|
|||
// +build ppc64
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -0,0 +1,9 @@
|
|||
// +build ppc64le
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -0,0 +1,9 @@
|
|||
// +build s390x
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -0,0 +1,89 @@
|
|||
// +build !windows,!plan9,!solaris
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// flock acquires an advisory lock on a file descriptor.
|
||||
func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error {
|
||||
var t time.Time
|
||||
for {
|
||||
// If we're beyond our timeout then return an error.
|
||||
// This can only occur after we've attempted a flock once.
|
||||
if t.IsZero() {
|
||||
t = time.Now()
|
||||
} else if timeout > 0 && time.Since(t) > timeout {
|
||||
return ErrTimeout
|
||||
}
|
||||
flag := syscall.LOCK_SH
|
||||
if exclusive {
|
||||
flag = syscall.LOCK_EX
|
||||
}
|
||||
|
||||
// Otherwise attempt to obtain an exclusive lock.
|
||||
err := syscall.Flock(int(db.file.Fd()), flag|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if err != syscall.EWOULDBLOCK {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for a bit and try again.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// funlock releases an advisory lock on a file descriptor.
|
||||
func funlock(db *DB) error {
|
||||
return syscall.Flock(int(db.file.Fd()), syscall.LOCK_UN)
|
||||
}
|
||||
|
||||
// mmap memory maps a DB's data file.
|
||||
func mmap(db *DB, sz int) error {
|
||||
// Map the data file to memory.
|
||||
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Advise the kernel that the mmap is accessed randomly.
|
||||
if err := madvise(b, syscall.MADV_RANDOM); err != nil {
|
||||
return fmt.Errorf("madvise: %s", err)
|
||||
}
|
||||
|
||||
// Save the original byte slice and convert to a byte array pointer.
|
||||
db.dataref = b
|
||||
db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0]))
|
||||
db.datasz = sz
|
||||
return nil
|
||||
}
|
||||
|
||||
// munmap unmaps a DB's data file from memory.
|
||||
func munmap(db *DB) error {
|
||||
// Ignore the unmap if we have no mapped data.
|
||||
if db.dataref == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmap using the original byte slice.
|
||||
err := syscall.Munmap(db.dataref)
|
||||
db.dataref = nil
|
||||
db.data = nil
|
||||
db.datasz = 0
|
||||
return err
|
||||
}
|
||||
|
||||
// NOTE: This function is copied from stdlib because it is not available on darwin.
|
||||
func madvise(b []byte, advice int) (err error) {
|
||||
_, _, e1 := syscall.Syscall(syscall.SYS_MADVISE, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(advice))
|
||||
if e1 != 0 {
|
||||
err = e1
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// flock acquires an advisory lock on a file descriptor.
|
||||
func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error {
|
||||
var t time.Time
|
||||
for {
|
||||
// If we're beyond our timeout then return an error.
|
||||
// This can only occur after we've attempted a flock once.
|
||||
if t.IsZero() {
|
||||
t = time.Now()
|
||||
} else if timeout > 0 && time.Since(t) > timeout {
|
||||
return ErrTimeout
|
||||
}
|
||||
var lock syscall.Flock_t
|
||||
lock.Start = 0
|
||||
lock.Len = 0
|
||||
lock.Pid = 0
|
||||
lock.Whence = 0
|
||||
lock.Pid = 0
|
||||
if exclusive {
|
||||
lock.Type = syscall.F_WRLCK
|
||||
} else {
|
||||
lock.Type = syscall.F_RDLCK
|
||||
}
|
||||
err := syscall.FcntlFlock(db.file.Fd(), syscall.F_SETLK, &lock)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if err != syscall.EAGAIN {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for a bit and try again.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// funlock releases an advisory lock on a file descriptor.
|
||||
func funlock(db *DB) error {
|
||||
var lock syscall.Flock_t
|
||||
lock.Start = 0
|
||||
lock.Len = 0
|
||||
lock.Type = syscall.F_UNLCK
|
||||
lock.Whence = 0
|
||||
return syscall.FcntlFlock(uintptr(db.file.Fd()), syscall.F_SETLK, &lock)
|
||||
}
|
||||
|
||||
// mmap memory maps a DB's data file.
|
||||
func mmap(db *DB, sz int) error {
|
||||
// Map the data file to memory.
|
||||
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Advise the kernel that the mmap is accessed randomly.
|
||||
if err := unix.Madvise(b, syscall.MADV_RANDOM); err != nil {
|
||||
return fmt.Errorf("madvise: %s", err)
|
||||
}
|
||||
|
||||
// Save the original byte slice and convert to a byte array pointer.
|
||||
db.dataref = b
|
||||
db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0]))
|
||||
db.datasz = sz
|
||||
return nil
|
||||
}
|
||||
|
||||
// munmap unmaps a DB's data file from memory.
|
||||
func munmap(db *DB) error {
|
||||
// Ignore the unmap if we have no mapped data.
|
||||
if db.dataref == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmap using the original byte slice.
|
||||
err := unix.Munmap(db.dataref)
|
||||
db.dataref = nil
|
||||
db.data = nil
|
||||
db.datasz = 0
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1
|
||||
var (
|
||||
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procLockFileEx = modkernel32.NewProc("LockFileEx")
|
||||
procUnlockFileEx = modkernel32.NewProc("UnlockFileEx")
|
||||
)
|
||||
|
||||
const (
|
||||
lockExt = ".lock"
|
||||
|
||||
// see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx
|
||||
flagLockExclusive = 2
|
||||
flagLockFailImmediately = 1
|
||||
|
||||
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx
|
||||
errLockViolation syscall.Errno = 0x21
|
||||
)
|
||||
|
||||
func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
|
||||
r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)))
|
||||
if r == 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
|
||||
r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0)
|
||||
if r == 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fdatasync flushes written data to a file descriptor.
|
||||
func fdatasync(db *DB) error {
|
||||
return db.file.Sync()
|
||||
}
|
||||
|
||||
// flock acquires an advisory lock on a file descriptor.
|
||||
func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error {
|
||||
// Create a separate lock file on windows because a process
|
||||
// cannot share an exclusive lock on the same file. This is
|
||||
// needed during Tx.WriteTo().
|
||||
f, err := os.OpenFile(db.path+lockExt, os.O_CREATE, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.lockfile = f
|
||||
|
||||
var t time.Time
|
||||
for {
|
||||
// If we're beyond our timeout then return an error.
|
||||
// This can only occur after we've attempted a flock once.
|
||||
if t.IsZero() {
|
||||
t = time.Now()
|
||||
} else if timeout > 0 && time.Since(t) > timeout {
|
||||
return ErrTimeout
|
||||
}
|
||||
|
||||
var flag uint32 = flagLockFailImmediately
|
||||
if exclusive {
|
||||
flag |= flagLockExclusive
|
||||
}
|
||||
|
||||
err := lockFileEx(syscall.Handle(db.lockfile.Fd()), flag, 0, 1, 0, &syscall.Overlapped{})
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if err != errLockViolation {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for a bit and try again.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// funlock releases an advisory lock on a file descriptor.
|
||||
func funlock(db *DB) error {
|
||||
err := unlockFileEx(syscall.Handle(db.lockfile.Fd()), 0, 1, 0, &syscall.Overlapped{})
|
||||
db.lockfile.Close()
|
||||
os.Remove(db.path+lockExt)
|
||||
return err
|
||||
}
|
||||
|
||||
// mmap memory maps a DB's data file.
|
||||
// Based on: https://github.com/edsrzf/mmap-go
|
||||
func mmap(db *DB, sz int) error {
|
||||
if !db.readOnly {
|
||||
// Truncate the database to the size of the mmap.
|
||||
if err := db.file.Truncate(int64(sz)); err != nil {
|
||||
return fmt.Errorf("truncate: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Open a file mapping handle.
|
||||
sizelo := uint32(sz >> 32)
|
||||
sizehi := uint32(sz) & 0xffffffff
|
||||
h, errno := syscall.CreateFileMapping(syscall.Handle(db.file.Fd()), nil, syscall.PAGE_READONLY, sizelo, sizehi, nil)
|
||||
if h == 0 {
|
||||
return os.NewSyscallError("CreateFileMapping", errno)
|
||||
}
|
||||
|
||||
// Create the memory map.
|
||||
addr, errno := syscall.MapViewOfFile(h, syscall.FILE_MAP_READ, 0, 0, uintptr(sz))
|
||||
if addr == 0 {
|
||||
return os.NewSyscallError("MapViewOfFile", errno)
|
||||
}
|
||||
|
||||
// Close mapping handle.
|
||||
if err := syscall.CloseHandle(syscall.Handle(h)); err != nil {
|
||||
return os.NewSyscallError("CloseHandle", err)
|
||||
}
|
||||
|
||||
// Convert to a byte array.
|
||||
db.data = ((*[maxMapSize]byte)(unsafe.Pointer(addr)))
|
||||
db.datasz = sz
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// munmap unmaps a pointer from a file.
|
||||
// Based on: https://github.com/edsrzf/mmap-go
|
||||
func munmap(db *DB) error {
|
||||
if db.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
addr := (uintptr)(unsafe.Pointer(&db.data[0]))
|
||||
if err := syscall.UnmapViewOfFile(addr); err != nil {
|
||||
return os.NewSyscallError("UnmapViewOfFile", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
// +build !windows,!plan9,!linux,!openbsd
|
||||
|
||||
package bolt
|
||||
|
||||
// fdatasync flushes written data to a file descriptor.
|
||||
func fdatasync(db *DB) error {
|
||||
return db.file.Sync()
|
||||
}
|
|
@ -0,0 +1,748 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxKeySize is the maximum length of a key, in bytes.
|
||||
MaxKeySize = 32768
|
||||
|
||||
// MaxValueSize is the maximum length of a value, in bytes.
|
||||
MaxValueSize = (1 << 31) - 2
|
||||
)
|
||||
|
||||
const (
|
||||
maxUint = ^uint(0)
|
||||
minUint = 0
|
||||
maxInt = int(^uint(0) >> 1)
|
||||
minInt = -maxInt - 1
|
||||
)
|
||||
|
||||
const bucketHeaderSize = int(unsafe.Sizeof(bucket{}))
|
||||
|
||||
const (
|
||||
minFillPercent = 0.1
|
||||
maxFillPercent = 1.0
|
||||
)
|
||||
|
||||
// DefaultFillPercent is the percentage that split pages are filled.
|
||||
// This value can be changed by setting Bucket.FillPercent.
|
||||
const DefaultFillPercent = 0.5
|
||||
|
||||
// Bucket represents a collection of key/value pairs inside the database.
|
||||
type Bucket struct {
|
||||
*bucket
|
||||
tx *Tx // the associated transaction
|
||||
buckets map[string]*Bucket // subbucket cache
|
||||
page *page // inline page reference
|
||||
rootNode *node // materialized node for the root page.
|
||||
nodes map[pgid]*node // node cache
|
||||
|
||||
// Sets the threshold for filling nodes when they split. By default,
|
||||
// the bucket will fill to 50% but it can be useful to increase this
|
||||
// amount if you know that your write workloads are mostly append-only.
|
||||
//
|
||||
// This is non-persisted across transactions so it must be set in every Tx.
|
||||
FillPercent float64
|
||||
}
|
||||
|
||||
// bucket represents the on-file representation of a bucket.
|
||||
// This is stored as the "value" of a bucket key. If the bucket is small enough,
|
||||
// then its root page can be stored inline in the "value", after the bucket
|
||||
// header. In the case of inline buckets, the "root" will be 0.
|
||||
type bucket struct {
|
||||
root pgid // page id of the bucket's root-level page
|
||||
sequence uint64 // monotonically incrementing, used by NextSequence()
|
||||
}
|
||||
|
||||
// newBucket returns a new bucket associated with a transaction.
|
||||
func newBucket(tx *Tx) Bucket {
|
||||
var b = Bucket{tx: tx, FillPercent: DefaultFillPercent}
|
||||
if tx.writable {
|
||||
b.buckets = make(map[string]*Bucket)
|
||||
b.nodes = make(map[pgid]*node)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Tx returns the tx of the bucket.
|
||||
func (b *Bucket) Tx() *Tx {
|
||||
return b.tx
|
||||
}
|
||||
|
||||
// Root returns the root of the bucket.
|
||||
func (b *Bucket) Root() pgid {
|
||||
return b.root
|
||||
}
|
||||
|
||||
// Writable returns whether the bucket is writable.
|
||||
func (b *Bucket) Writable() bool {
|
||||
return b.tx.writable
|
||||
}
|
||||
|
||||
// Cursor creates a cursor associated with the bucket.
|
||||
// The cursor is only valid as long as the transaction is open.
|
||||
// Do not use a cursor after the transaction is closed.
|
||||
func (b *Bucket) Cursor() *Cursor {
|
||||
// Update transaction statistics.
|
||||
b.tx.stats.CursorCount++
|
||||
|
||||
// Allocate and return a cursor.
|
||||
return &Cursor{
|
||||
bucket: b,
|
||||
stack: make([]elemRef, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Bucket retrieves a nested bucket by name.
|
||||
// Returns nil if the bucket does not exist.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (b *Bucket) Bucket(name []byte) *Bucket {
|
||||
if b.buckets != nil {
|
||||
if child := b.buckets[string(name)]; child != nil {
|
||||
return child
|
||||
}
|
||||
}
|
||||
|
||||
// Move cursor to key.
|
||||
c := b.Cursor()
|
||||
k, v, flags := c.seek(name)
|
||||
|
||||
// Return nil if the key doesn't exist or it is not a bucket.
|
||||
if !bytes.Equal(name, k) || (flags&bucketLeafFlag) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise create a bucket and cache it.
|
||||
var child = b.openBucket(v)
|
||||
if b.buckets != nil {
|
||||
b.buckets[string(name)] = child
|
||||
}
|
||||
|
||||
return child
|
||||
}
|
||||
|
||||
// Helper method that re-interprets a sub-bucket value
|
||||
// from a parent into a Bucket
|
||||
func (b *Bucket) openBucket(value []byte) *Bucket {
|
||||
var child = newBucket(b.tx)
|
||||
|
||||
// If this is a writable transaction then we need to copy the bucket entry.
|
||||
// Read-only transactions can point directly at the mmap entry.
|
||||
if b.tx.writable {
|
||||
child.bucket = &bucket{}
|
||||
*child.bucket = *(*bucket)(unsafe.Pointer(&value[0]))
|
||||
} else {
|
||||
child.bucket = (*bucket)(unsafe.Pointer(&value[0]))
|
||||
}
|
||||
|
||||
// Save a reference to the inline page if the bucket is inline.
|
||||
if child.root == 0 {
|
||||
child.page = (*page)(unsafe.Pointer(&value[bucketHeaderSize]))
|
||||
}
|
||||
|
||||
return &child
|
||||
}
|
||||
|
||||
// CreateBucket creates a new bucket at the given key and returns the new bucket.
|
||||
// Returns an error if the key already exists, if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) {
|
||||
if b.tx.db == nil {
|
||||
return nil, ErrTxClosed
|
||||
} else if !b.tx.writable {
|
||||
return nil, ErrTxNotWritable
|
||||
} else if len(key) == 0 {
|
||||
return nil, ErrBucketNameRequired
|
||||
}
|
||||
|
||||
// Move cursor to correct position.
|
||||
c := b.Cursor()
|
||||
k, _, flags := c.seek(key)
|
||||
|
||||
// Return an error if there is an existing key.
|
||||
if bytes.Equal(key, k) {
|
||||
if (flags & bucketLeafFlag) != 0 {
|
||||
return nil, ErrBucketExists
|
||||
} else {
|
||||
return nil, ErrIncompatibleValue
|
||||
}
|
||||
}
|
||||
|
||||
// Create empty, inline bucket.
|
||||
var bucket = Bucket{
|
||||
bucket: &bucket{},
|
||||
rootNode: &node{isLeaf: true},
|
||||
FillPercent: DefaultFillPercent,
|
||||
}
|
||||
var value = bucket.write()
|
||||
|
||||
// Insert into node.
|
||||
key = cloneBytes(key)
|
||||
c.node().put(key, key, value, 0, bucketLeafFlag)
|
||||
|
||||
// Since subbuckets are not allowed on inline buckets, we need to
|
||||
// dereference the inline page, if it exists. This will cause the bucket
|
||||
// to be treated as a regular, non-inline bucket for the rest of the tx.
|
||||
b.page = nil
|
||||
|
||||
return b.Bucket(key), nil
|
||||
}
|
||||
|
||||
// CreateBucketIfNotExists creates a new bucket if it doesn't already exist and returns a reference to it.
|
||||
// Returns an error if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (b *Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) {
|
||||
child, err := b.CreateBucket(key)
|
||||
if err == ErrBucketExists {
|
||||
return b.Bucket(key), nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return child, nil
|
||||
}
|
||||
|
||||
// DeleteBucket deletes a bucket at the given key.
|
||||
// Returns an error if the bucket does not exists, or if the key represents a non-bucket value.
|
||||
func (b *Bucket) DeleteBucket(key []byte) error {
|
||||
if b.tx.db == nil {
|
||||
return ErrTxClosed
|
||||
} else if !b.Writable() {
|
||||
return ErrTxNotWritable
|
||||
}
|
||||
|
||||
// Move cursor to correct position.
|
||||
c := b.Cursor()
|
||||
k, _, flags := c.seek(key)
|
||||
|
||||
// Return an error if bucket doesn't exist or is not a bucket.
|
||||
if !bytes.Equal(key, k) {
|
||||
return ErrBucketNotFound
|
||||
} else if (flags & bucketLeafFlag) == 0 {
|
||||
return ErrIncompatibleValue
|
||||
}
|
||||
|
||||
// Recursively delete all child buckets.
|
||||
child := b.Bucket(key)
|
||||
err := child.ForEach(func(k, v []byte) error {
|
||||
if v == nil {
|
||||
if err := child.DeleteBucket(k); err != nil {
|
||||
return fmt.Errorf("delete bucket: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove cached copy.
|
||||
delete(b.buckets, string(key))
|
||||
|
||||
// Release all bucket pages to freelist.
|
||||
child.nodes = nil
|
||||
child.rootNode = nil
|
||||
child.free()
|
||||
|
||||
// Delete the node if we have a matching key.
|
||||
c.node().del(key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves the value for a key in the bucket.
|
||||
// Returns a nil value if the key does not exist or if the key is a nested bucket.
|
||||
// The returned value is only valid for the life of the transaction.
|
||||
func (b *Bucket) Get(key []byte) []byte {
|
||||
k, v, flags := b.Cursor().seek(key)
|
||||
|
||||
// Return nil if this is a bucket.
|
||||
if (flags & bucketLeafFlag) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If our target node isn't the same key as what's passed in then return nil.
|
||||
if !bytes.Equal(key, k) {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Put sets the value for a key in the bucket.
|
||||
// If the key exist then its previous value will be overwritten.
|
||||
// Supplied value must remain valid for the life of the transaction.
|
||||
// Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large.
|
||||
func (b *Bucket) Put(key []byte, value []byte) error {
|
||||
if b.tx.db == nil {
|
||||
return ErrTxClosed
|
||||
} else if !b.Writable() {
|
||||
return ErrTxNotWritable
|
||||
} else if len(key) == 0 {
|
||||
return ErrKeyRequired
|
||||
} else if len(key) > MaxKeySize {
|
||||
return ErrKeyTooLarge
|
||||
} else if int64(len(value)) > MaxValueSize {
|
||||
return ErrValueTooLarge
|
||||
}
|
||||
|
||||
// Move cursor to correct position.
|
||||
c := b.Cursor()
|
||||
k, _, flags := c.seek(key)
|
||||
|
||||
// Return an error if there is an existing key with a bucket value.
|
||||
if bytes.Equal(key, k) && (flags&bucketLeafFlag) != 0 {
|
||||
return ErrIncompatibleValue
|
||||
}
|
||||
|
||||
// Insert into node.
|
||||
key = cloneBytes(key)
|
||||
c.node().put(key, key, value, 0, 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the bucket.
|
||||
// If the key does not exist then nothing is done and a nil error is returned.
|
||||
// Returns an error if the bucket was created from a read-only transaction.
|
||||
func (b *Bucket) Delete(key []byte) error {
|
||||
if b.tx.db == nil {
|
||||
return ErrTxClosed
|
||||
} else if !b.Writable() {
|
||||
return ErrTxNotWritable
|
||||
}
|
||||
|
||||
// Move cursor to correct position.
|
||||
c := b.Cursor()
|
||||
_, _, flags := c.seek(key)
|
||||
|
||||
// Return an error if there is already existing bucket value.
|
||||
if (flags & bucketLeafFlag) != 0 {
|
||||
return ErrIncompatibleValue
|
||||
}
|
||||
|
||||
// Delete the node if we have a matching key.
|
||||
c.node().del(key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextSequence returns an autoincrementing integer for the bucket.
|
||||
func (b *Bucket) NextSequence() (uint64, error) {
|
||||
if b.tx.db == nil {
|
||||
return 0, ErrTxClosed
|
||||
} else if !b.Writable() {
|
||||
return 0, ErrTxNotWritable
|
||||
}
|
||||
|
||||
// Materialize the root node if it hasn't been already so that the
|
||||
// bucket will be saved during commit.
|
||||
if b.rootNode == nil {
|
||||
_ = b.node(b.root, nil)
|
||||
}
|
||||
|
||||
// Increment and return the sequence.
|
||||
b.bucket.sequence++
|
||||
return b.bucket.sequence, nil
|
||||
}
|
||||
|
||||
// ForEach executes a function for each key/value pair in a bucket.
|
||||
// If the provided function returns an error then the iteration is stopped and
|
||||
// the error is returned to the caller. The provided function must not modify
|
||||
// the bucket; this will result in undefined behavior.
|
||||
func (b *Bucket) ForEach(fn func(k, v []byte) error) error {
|
||||
if b.tx.db == nil {
|
||||
return ErrTxClosed
|
||||
}
|
||||
c := b.Cursor()
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
if err := fn(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stat returns stats on a bucket.
|
||||
func (b *Bucket) Stats() BucketStats {
|
||||
var s, subStats BucketStats
|
||||
pageSize := b.tx.db.pageSize
|
||||
s.BucketN += 1
|
||||
if b.root == 0 {
|
||||
s.InlineBucketN += 1
|
||||
}
|
||||
b.forEachPage(func(p *page, depth int) {
|
||||
if (p.flags & leafPageFlag) != 0 {
|
||||
s.KeyN += int(p.count)
|
||||
|
||||
// used totals the used bytes for the page
|
||||
used := pageHeaderSize
|
||||
|
||||
if p.count != 0 {
|
||||
// If page has any elements, add all element headers.
|
||||
used += leafPageElementSize * int(p.count-1)
|
||||
|
||||
// Add all element key, value sizes.
|
||||
// The computation takes advantage of the fact that the position
|
||||
// of the last element's key/value equals to the total of the sizes
|
||||
// of all previous elements' keys and values.
|
||||
// It also includes the last element's header.
|
||||
lastElement := p.leafPageElement(p.count - 1)
|
||||
used += int(lastElement.pos + lastElement.ksize + lastElement.vsize)
|
||||
}
|
||||
|
||||
if b.root == 0 {
|
||||
// For inlined bucket just update the inline stats
|
||||
s.InlineBucketInuse += used
|
||||
} else {
|
||||
// For non-inlined bucket update all the leaf stats
|
||||
s.LeafPageN++
|
||||
s.LeafInuse += used
|
||||
s.LeafOverflowN += int(p.overflow)
|
||||
|
||||
// Collect stats from sub-buckets.
|
||||
// Do that by iterating over all element headers
|
||||
// looking for the ones with the bucketLeafFlag.
|
||||
for i := uint16(0); i < p.count; i++ {
|
||||
e := p.leafPageElement(i)
|
||||
if (e.flags & bucketLeafFlag) != 0 {
|
||||
// For any bucket element, open the element value
|
||||
// and recursively call Stats on the contained bucket.
|
||||
subStats.Add(b.openBucket(e.value()).Stats())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (p.flags & branchPageFlag) != 0 {
|
||||
s.BranchPageN++
|
||||
lastElement := p.branchPageElement(p.count - 1)
|
||||
|
||||
// used totals the used bytes for the page
|
||||
// Add header and all element headers.
|
||||
used := pageHeaderSize + (branchPageElementSize * int(p.count-1))
|
||||
|
||||
// Add size of all keys and values.
|
||||
// Again, use the fact that last element's position equals to
|
||||
// the total of key, value sizes of all previous elements.
|
||||
used += int(lastElement.pos + lastElement.ksize)
|
||||
s.BranchInuse += used
|
||||
s.BranchOverflowN += int(p.overflow)
|
||||
}
|
||||
|
||||
// Keep track of maximum page depth.
|
||||
if depth+1 > s.Depth {
|
||||
s.Depth = (depth + 1)
|
||||
}
|
||||
})
|
||||
|
||||
// Alloc stats can be computed from page counts and pageSize.
|
||||
s.BranchAlloc = (s.BranchPageN + s.BranchOverflowN) * pageSize
|
||||
s.LeafAlloc = (s.LeafPageN + s.LeafOverflowN) * pageSize
|
||||
|
||||
// Add the max depth of sub-buckets to get total nested depth.
|
||||
s.Depth += subStats.Depth
|
||||
// Add the stats for all sub-buckets
|
||||
s.Add(subStats)
|
||||
return s
|
||||
}
|
||||
|
||||
// forEachPage iterates over every page in a bucket, including inline pages.
|
||||
func (b *Bucket) forEachPage(fn func(*page, int)) {
|
||||
// If we have an inline page then just use that.
|
||||
if b.page != nil {
|
||||
fn(b.page, 0)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise traverse the page hierarchy.
|
||||
b.tx.forEachPage(b.root, 0, fn)
|
||||
}
|
||||
|
||||
// forEachPageNode iterates over every page (or node) in a bucket.
|
||||
// This also includes inline pages.
|
||||
func (b *Bucket) forEachPageNode(fn func(*page, *node, int)) {
|
||||
// If we have an inline page or root node then just use that.
|
||||
if b.page != nil {
|
||||
fn(b.page, nil, 0)
|
||||
return
|
||||
}
|
||||
b._forEachPageNode(b.root, 0, fn)
|
||||
}
|
||||
|
||||
func (b *Bucket) _forEachPageNode(pgid pgid, depth int, fn func(*page, *node, int)) {
|
||||
var p, n = b.pageNode(pgid)
|
||||
|
||||
// Execute function.
|
||||
fn(p, n, depth)
|
||||
|
||||
// Recursively loop over children.
|
||||
if p != nil {
|
||||
if (p.flags & branchPageFlag) != 0 {
|
||||
for i := 0; i < int(p.count); i++ {
|
||||
elem := p.branchPageElement(uint16(i))
|
||||
b._forEachPageNode(elem.pgid, depth+1, fn)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !n.isLeaf {
|
||||
for _, inode := range n.inodes {
|
||||
b._forEachPageNode(inode.pgid, depth+1, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// spill writes all the nodes for this bucket to dirty pages.
|
||||
func (b *Bucket) spill() error {
|
||||
// Spill all child buckets first.
|
||||
for name, child := range b.buckets {
|
||||
// If the child bucket is small enough and it has no child buckets then
|
||||
// write it inline into the parent bucket's page. Otherwise spill it
|
||||
// like a normal bucket and make the parent value a pointer to the page.
|
||||
var value []byte
|
||||
if child.inlineable() {
|
||||
child.free()
|
||||
value = child.write()
|
||||
} else {
|
||||
if err := child.spill(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the child bucket header in this bucket.
|
||||
value = make([]byte, unsafe.Sizeof(bucket{}))
|
||||
var bucket = (*bucket)(unsafe.Pointer(&value[0]))
|
||||
*bucket = *child.bucket
|
||||
}
|
||||
|
||||
// Skip writing the bucket if there are no materialized nodes.
|
||||
if child.rootNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update parent node.
|
||||
var c = b.Cursor()
|
||||
k, _, flags := c.seek([]byte(name))
|
||||
if !bytes.Equal([]byte(name), k) {
|
||||
panic(fmt.Sprintf("misplaced bucket header: %x -> %x", []byte(name), k))
|
||||
}
|
||||
if flags&bucketLeafFlag == 0 {
|
||||
panic(fmt.Sprintf("unexpected bucket header flag: %x", flags))
|
||||
}
|
||||
c.node().put([]byte(name), []byte(name), value, 0, bucketLeafFlag)
|
||||
}
|
||||
|
||||
// Ignore if there's not a materialized root node.
|
||||
if b.rootNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Spill nodes.
|
||||
if err := b.rootNode.spill(); err != nil {
|
||||
return err
|
||||
}
|
||||
b.rootNode = b.rootNode.root()
|
||||
|
||||
// Update the root node for this bucket.
|
||||
if b.rootNode.pgid >= b.tx.meta.pgid {
|
||||
panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", b.rootNode.pgid, b.tx.meta.pgid))
|
||||
}
|
||||
b.root = b.rootNode.pgid
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// inlineable returns true if a bucket is small enough to be written inline
|
||||
// and if it contains no subbuckets. Otherwise returns false.
|
||||
func (b *Bucket) inlineable() bool {
|
||||
var n = b.rootNode
|
||||
|
||||
// Bucket must only contain a single leaf node.
|
||||
if n == nil || !n.isLeaf {
|
||||
return false
|
||||
}
|
||||
|
||||
// Bucket is not inlineable if it contains subbuckets or if it goes beyond
|
||||
// our threshold for inline bucket size.
|
||||
var size = pageHeaderSize
|
||||
for _, inode := range n.inodes {
|
||||
size += leafPageElementSize + len(inode.key) + len(inode.value)
|
||||
|
||||
if inode.flags&bucketLeafFlag != 0 {
|
||||
return false
|
||||
} else if size > b.maxInlineBucketSize() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Returns the maximum total size of a bucket to make it a candidate for inlining.
|
||||
func (b *Bucket) maxInlineBucketSize() int {
|
||||
return b.tx.db.pageSize / 4
|
||||
}
|
||||
|
||||
// write allocates and writes a bucket to a byte slice.
|
||||
func (b *Bucket) write() []byte {
|
||||
// Allocate the appropriate size.
|
||||
var n = b.rootNode
|
||||
var value = make([]byte, bucketHeaderSize+n.size())
|
||||
|
||||
// Write a bucket header.
|
||||
var bucket = (*bucket)(unsafe.Pointer(&value[0]))
|
||||
*bucket = *b.bucket
|
||||
|
||||
// Convert byte slice to a fake page and write the root node.
|
||||
var p = (*page)(unsafe.Pointer(&value[bucketHeaderSize]))
|
||||
n.write(p)
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// rebalance attempts to balance all nodes.
|
||||
func (b *Bucket) rebalance() {
|
||||
for _, n := range b.nodes {
|
||||
n.rebalance()
|
||||
}
|
||||
for _, child := range b.buckets {
|
||||
child.rebalance()
|
||||
}
|
||||
}
|
||||
|
||||
// node creates a node from a page and associates it with a given parent.
|
||||
func (b *Bucket) node(pgid pgid, parent *node) *node {
|
||||
_assert(b.nodes != nil, "nodes map expected")
|
||||
|
||||
// Retrieve node if it's already been created.
|
||||
if n := b.nodes[pgid]; n != nil {
|
||||
return n
|
||||
}
|
||||
|
||||
// Otherwise create a node and cache it.
|
||||
n := &node{bucket: b, parent: parent}
|
||||
if parent == nil {
|
||||
b.rootNode = n
|
||||
} else {
|
||||
parent.children = append(parent.children, n)
|
||||
}
|
||||
|
||||
// Use the inline page if this is an inline bucket.
|
||||
var p = b.page
|
||||
if p == nil {
|
||||
p = b.tx.page(pgid)
|
||||
}
|
||||
|
||||
// Read the page into the node and cache it.
|
||||
n.read(p)
|
||||
b.nodes[pgid] = n
|
||||
|
||||
// Update statistics.
|
||||
b.tx.stats.NodeCount++
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// free recursively frees all pages in the bucket.
|
||||
func (b *Bucket) free() {
|
||||
if b.root == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var tx = b.tx
|
||||
b.forEachPageNode(func(p *page, n *node, _ int) {
|
||||
if p != nil {
|
||||
tx.db.freelist.free(tx.meta.txid, p)
|
||||
} else {
|
||||
n.free()
|
||||
}
|
||||
})
|
||||
b.root = 0
|
||||
}
|
||||
|
||||
// dereference removes all references to the old mmap.
|
||||
func (b *Bucket) dereference() {
|
||||
if b.rootNode != nil {
|
||||
b.rootNode.root().dereference()
|
||||
}
|
||||
|
||||
for _, child := range b.buckets {
|
||||
child.dereference()
|
||||
}
|
||||
}
|
||||
|
||||
// pageNode returns the in-memory node, if it exists.
|
||||
// Otherwise returns the underlying page.
|
||||
func (b *Bucket) pageNode(id pgid) (*page, *node) {
|
||||
// Inline buckets have a fake page embedded in their value so treat them
|
||||
// differently. We'll return the rootNode (if available) or the fake page.
|
||||
if b.root == 0 {
|
||||
if id != 0 {
|
||||
panic(fmt.Sprintf("inline bucket non-zero page access(2): %d != 0", id))
|
||||
}
|
||||
if b.rootNode != nil {
|
||||
return nil, b.rootNode
|
||||
}
|
||||
return b.page, nil
|
||||
}
|
||||
|
||||
// Check the node cache for non-inline buckets.
|
||||
if b.nodes != nil {
|
||||
if n := b.nodes[id]; n != nil {
|
||||
return nil, n
|
||||
}
|
||||
}
|
||||
|
||||
// Finally lookup the page from the transaction if no node is materialized.
|
||||
return b.tx.page(id), nil
|
||||
}
|
||||
|
||||
// BucketStats records statistics about resources used by a bucket.
|
||||
type BucketStats struct {
|
||||
// Page count statistics.
|
||||
BranchPageN int // number of logical branch pages
|
||||
BranchOverflowN int // number of physical branch overflow pages
|
||||
LeafPageN int // number of logical leaf pages
|
||||
LeafOverflowN int // number of physical leaf overflow pages
|
||||
|
||||
// Tree statistics.
|
||||
KeyN int // number of keys/value pairs
|
||||
Depth int // number of levels in B+tree
|
||||
|
||||
// Page size utilization.
|
||||
BranchAlloc int // bytes allocated for physical branch pages
|
||||
BranchInuse int // bytes actually used for branch data
|
||||
LeafAlloc int // bytes allocated for physical leaf pages
|
||||
LeafInuse int // bytes actually used for leaf data
|
||||
|
||||
// Bucket statistics
|
||||
BucketN int // total number of buckets including the top bucket
|
||||
InlineBucketN int // total number on inlined buckets
|
||||
InlineBucketInuse int // bytes used for inlined buckets (also accounted for in LeafInuse)
|
||||
}
|
||||
|
||||
func (s *BucketStats) Add(other BucketStats) {
|
||||
s.BranchPageN += other.BranchPageN
|
||||
s.BranchOverflowN += other.BranchOverflowN
|
||||
s.LeafPageN += other.LeafPageN
|
||||
s.LeafOverflowN += other.LeafOverflowN
|
||||
s.KeyN += other.KeyN
|
||||
if s.Depth < other.Depth {
|
||||
s.Depth = other.Depth
|
||||
}
|
||||
s.BranchAlloc += other.BranchAlloc
|
||||
s.BranchInuse += other.BranchInuse
|
||||
s.LeafAlloc += other.LeafAlloc
|
||||
s.LeafInuse += other.LeafInuse
|
||||
|
||||
s.BucketN += other.BucketN
|
||||
s.InlineBucketN += other.InlineBucketN
|
||||
s.InlineBucketInuse += other.InlineBucketInuse
|
||||
}
|
||||
|
||||
// cloneBytes returns a copy of a given slice.
|
||||
func cloneBytes(v []byte) []byte {
|
||||
var clone = make([]byte, len(v))
|
||||
copy(clone, v)
|
||||
return clone
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,400 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Cursor represents an iterator that can traverse over all key/value pairs in a bucket in sorted order.
|
||||
// Cursors see nested buckets with value == nil.
|
||||
// Cursors can be obtained from a transaction and are valid as long as the transaction is open.
|
||||
//
|
||||
// Keys and values returned from the cursor are only valid for the life of the transaction.
|
||||
//
|
||||
// Changing data while traversing with a cursor may cause it to be invalidated
|
||||
// and return unexpected keys and/or values. You must reposition your cursor
|
||||
// after mutating data.
|
||||
type Cursor struct {
|
||||
bucket *Bucket
|
||||
stack []elemRef
|
||||
}
|
||||
|
||||
// Bucket returns the bucket that this cursor was created from.
|
||||
func (c *Cursor) Bucket() *Bucket {
|
||||
return c.bucket
|
||||
}
|
||||
|
||||
// First moves the cursor to the first item in the bucket and returns its key and value.
|
||||
// If the bucket is empty then a nil key and value are returned.
|
||||
// The returned key and value are only valid for the life of the transaction.
|
||||
func (c *Cursor) First() (key []byte, value []byte) {
|
||||
_assert(c.bucket.tx.db != nil, "tx closed")
|
||||
c.stack = c.stack[:0]
|
||||
p, n := c.bucket.pageNode(c.bucket.root)
|
||||
c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
|
||||
c.first()
|
||||
|
||||
// If we land on an empty page then move to the next value.
|
||||
// https://github.com/boltdb/bolt/issues/450
|
||||
if c.stack[len(c.stack)-1].count() == 0 {
|
||||
c.next()
|
||||
}
|
||||
|
||||
k, v, flags := c.keyValue()
|
||||
if (flags & uint32(bucketLeafFlag)) != 0 {
|
||||
return k, nil
|
||||
}
|
||||
return k, v
|
||||
|
||||
}
|
||||
|
||||
// Last moves the cursor to the last item in the bucket and returns its key and value.
|
||||
// If the bucket is empty then a nil key and value are returned.
|
||||
// The returned key and value are only valid for the life of the transaction.
|
||||
func (c *Cursor) Last() (key []byte, value []byte) {
|
||||
_assert(c.bucket.tx.db != nil, "tx closed")
|
||||
c.stack = c.stack[:0]
|
||||
p, n := c.bucket.pageNode(c.bucket.root)
|
||||
ref := elemRef{page: p, node: n}
|
||||
ref.index = ref.count() - 1
|
||||
c.stack = append(c.stack, ref)
|
||||
c.last()
|
||||
k, v, flags := c.keyValue()
|
||||
if (flags & uint32(bucketLeafFlag)) != 0 {
|
||||
return k, nil
|
||||
}
|
||||
return k, v
|
||||
}
|
||||
|
||||
// Next moves the cursor to the next item in the bucket and returns its key and value.
|
||||
// If the cursor is at the end of the bucket then a nil key and value are returned.
|
||||
// The returned key and value are only valid for the life of the transaction.
|
||||
func (c *Cursor) Next() (key []byte, value []byte) {
|
||||
_assert(c.bucket.tx.db != nil, "tx closed")
|
||||
k, v, flags := c.next()
|
||||
if (flags & uint32(bucketLeafFlag)) != 0 {
|
||||
return k, nil
|
||||
}
|
||||
return k, v
|
||||
}
|
||||
|
||||
// Prev moves the cursor to the previous item in the bucket and returns its key and value.
|
||||
// If the cursor is at the beginning of the bucket then a nil key and value are returned.
|
||||
// The returned key and value are only valid for the life of the transaction.
|
||||
func (c *Cursor) Prev() (key []byte, value []byte) {
|
||||
_assert(c.bucket.tx.db != nil, "tx closed")
|
||||
|
||||
// Attempt to move back one element until we're successful.
|
||||
// Move up the stack as we hit the beginning of each page in our stack.
|
||||
for i := len(c.stack) - 1; i >= 0; i-- {
|
||||
elem := &c.stack[i]
|
||||
if elem.index > 0 {
|
||||
elem.index--
|
||||
break
|
||||
}
|
||||
c.stack = c.stack[:i]
|
||||
}
|
||||
|
||||
// If we've hit the end then return nil.
|
||||
if len(c.stack) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Move down the stack to find the last element of the last leaf under this branch.
|
||||
c.last()
|
||||
k, v, flags := c.keyValue()
|
||||
if (flags & uint32(bucketLeafFlag)) != 0 {
|
||||
return k, nil
|
||||
}
|
||||
return k, v
|
||||
}
|
||||
|
||||
// Seek moves the cursor to a given key and returns it.
|
||||
// If the key does not exist then the next key is used. If no keys
|
||||
// follow, a nil key is returned.
|
||||
// The returned key and value are only valid for the life of the transaction.
|
||||
func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) {
|
||||
k, v, flags := c.seek(seek)
|
||||
|
||||
// If we ended up after the last element of a page then move to the next one.
|
||||
if ref := &c.stack[len(c.stack)-1]; ref.index >= ref.count() {
|
||||
k, v, flags = c.next()
|
||||
}
|
||||
|
||||
if k == nil {
|
||||
return nil, nil
|
||||
} else if (flags & uint32(bucketLeafFlag)) != 0 {
|
||||
return k, nil
|
||||
}
|
||||
return k, v
|
||||
}
|
||||
|
||||
// Delete removes the current key/value under the cursor from the bucket.
|
||||
// Delete fails if current key/value is a bucket or if the transaction is not writable.
|
||||
func (c *Cursor) Delete() error {
|
||||
if c.bucket.tx.db == nil {
|
||||
return ErrTxClosed
|
||||
} else if !c.bucket.Writable() {
|
||||
return ErrTxNotWritable
|
||||
}
|
||||
|
||||
key, _, flags := c.keyValue()
|
||||
// Return an error if current value is a bucket.
|
||||
if (flags & bucketLeafFlag) != 0 {
|
||||
return ErrIncompatibleValue
|
||||
}
|
||||
c.node().del(key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// seek moves the cursor to a given key and returns it.
|
||||
// If the key does not exist then the next key is used.
|
||||
func (c *Cursor) seek(seek []byte) (key []byte, value []byte, flags uint32) {
|
||||
_assert(c.bucket.tx.db != nil, "tx closed")
|
||||
|
||||
// Start from root page/node and traverse to correct page.
|
||||
c.stack = c.stack[:0]
|
||||
c.search(seek, c.bucket.root)
|
||||
ref := &c.stack[len(c.stack)-1]
|
||||
|
||||
// If the cursor is pointing to the end of page/node then return nil.
|
||||
if ref.index >= ref.count() {
|
||||
return nil, nil, 0
|
||||
}
|
||||
|
||||
// If this is a bucket then return a nil value.
|
||||
return c.keyValue()
|
||||
}
|
||||
|
||||
// first moves the cursor to the first leaf element under the last page in the stack.
|
||||
func (c *Cursor) first() {
|
||||
for {
|
||||
// Exit when we hit a leaf page.
|
||||
var ref = &c.stack[len(c.stack)-1]
|
||||
if ref.isLeaf() {
|
||||
break
|
||||
}
|
||||
|
||||
// Keep adding pages pointing to the first element to the stack.
|
||||
var pgid pgid
|
||||
if ref.node != nil {
|
||||
pgid = ref.node.inodes[ref.index].pgid
|
||||
} else {
|
||||
pgid = ref.page.branchPageElement(uint16(ref.index)).pgid
|
||||
}
|
||||
p, n := c.bucket.pageNode(pgid)
|
||||
c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
|
||||
}
|
||||
}
|
||||
|
||||
// last moves the cursor to the last leaf element under the last page in the stack.
|
||||
func (c *Cursor) last() {
|
||||
for {
|
||||
// Exit when we hit a leaf page.
|
||||
ref := &c.stack[len(c.stack)-1]
|
||||
if ref.isLeaf() {
|
||||
break
|
||||
}
|
||||
|
||||
// Keep adding pages pointing to the last element in the stack.
|
||||
var pgid pgid
|
||||
if ref.node != nil {
|
||||
pgid = ref.node.inodes[ref.index].pgid
|
||||
} else {
|
||||
pgid = ref.page.branchPageElement(uint16(ref.index)).pgid
|
||||
}
|
||||
p, n := c.bucket.pageNode(pgid)
|
||||
|
||||
var nextRef = elemRef{page: p, node: n}
|
||||
nextRef.index = nextRef.count() - 1
|
||||
c.stack = append(c.stack, nextRef)
|
||||
}
|
||||
}
|
||||
|
||||
// next moves to the next leaf element and returns the key and value.
|
||||
// If the cursor is at the last leaf element then it stays there and returns nil.
|
||||
func (c *Cursor) next() (key []byte, value []byte, flags uint32) {
|
||||
for {
|
||||
// Attempt to move over one element until we're successful.
|
||||
// Move up the stack as we hit the end of each page in our stack.
|
||||
var i int
|
||||
for i = len(c.stack) - 1; i >= 0; i-- {
|
||||
elem := &c.stack[i]
|
||||
if elem.index < elem.count()-1 {
|
||||
elem.index++
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we've hit the root page then stop and return. This will leave the
|
||||
// cursor on the last element of the last page.
|
||||
if i == -1 {
|
||||
return nil, nil, 0
|
||||
}
|
||||
|
||||
// Otherwise start from where we left off in the stack and find the
|
||||
// first element of the first leaf page.
|
||||
c.stack = c.stack[:i+1]
|
||||
c.first()
|
||||
|
||||
// If this is an empty page then restart and move back up the stack.
|
||||
// https://github.com/boltdb/bolt/issues/450
|
||||
if c.stack[len(c.stack)-1].count() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return c.keyValue()
|
||||
}
|
||||
}
|
||||
|
||||
// search recursively performs a binary search against a given page/node until it finds a given key.
|
||||
func (c *Cursor) search(key []byte, pgid pgid) {
|
||||
p, n := c.bucket.pageNode(pgid)
|
||||
if p != nil && (p.flags&(branchPageFlag|leafPageFlag)) == 0 {
|
||||
panic(fmt.Sprintf("invalid page type: %d: %x", p.id, p.flags))
|
||||
}
|
||||
e := elemRef{page: p, node: n}
|
||||
c.stack = append(c.stack, e)
|
||||
|
||||
// If we're on a leaf page/node then find the specific node.
|
||||
if e.isLeaf() {
|
||||
c.nsearch(key)
|
||||
return
|
||||
}
|
||||
|
||||
if n != nil {
|
||||
c.searchNode(key, n)
|
||||
return
|
||||
}
|
||||
c.searchPage(key, p)
|
||||
}
|
||||
|
||||
func (c *Cursor) searchNode(key []byte, n *node) {
|
||||
var exact bool
|
||||
index := sort.Search(len(n.inodes), func(i int) bool {
|
||||
// TODO(benbjohnson): Optimize this range search. It's a bit hacky right now.
|
||||
// sort.Search() finds the lowest index where f() != -1 but we need the highest index.
|
||||
ret := bytes.Compare(n.inodes[i].key, key)
|
||||
if ret == 0 {
|
||||
exact = true
|
||||
}
|
||||
return ret != -1
|
||||
})
|
||||
if !exact && index > 0 {
|
||||
index--
|
||||
}
|
||||
c.stack[len(c.stack)-1].index = index
|
||||
|
||||
// Recursively search to the next page.
|
||||
c.search(key, n.inodes[index].pgid)
|
||||
}
|
||||
|
||||
func (c *Cursor) searchPage(key []byte, p *page) {
|
||||
// Binary search for the correct range.
|
||||
inodes := p.branchPageElements()
|
||||
|
||||
var exact bool
|
||||
index := sort.Search(int(p.count), func(i int) bool {
|
||||
// TODO(benbjohnson): Optimize this range search. It's a bit hacky right now.
|
||||
// sort.Search() finds the lowest index where f() != -1 but we need the highest index.
|
||||
ret := bytes.Compare(inodes[i].key(), key)
|
||||
if ret == 0 {
|
||||
exact = true
|
||||
}
|
||||
return ret != -1
|
||||
})
|
||||
if !exact && index > 0 {
|
||||
index--
|
||||
}
|
||||
c.stack[len(c.stack)-1].index = index
|
||||
|
||||
// Recursively search to the next page.
|
||||
c.search(key, inodes[index].pgid)
|
||||
}
|
||||
|
||||
// nsearch searches the leaf node on the top of the stack for a key.
|
||||
func (c *Cursor) nsearch(key []byte) {
|
||||
e := &c.stack[len(c.stack)-1]
|
||||
p, n := e.page, e.node
|
||||
|
||||
// If we have a node then search its inodes.
|
||||
if n != nil {
|
||||
index := sort.Search(len(n.inodes), func(i int) bool {
|
||||
return bytes.Compare(n.inodes[i].key, key) != -1
|
||||
})
|
||||
e.index = index
|
||||
return
|
||||
}
|
||||
|
||||
// If we have a page then search its leaf elements.
|
||||
inodes := p.leafPageElements()
|
||||
index := sort.Search(int(p.count), func(i int) bool {
|
||||
return bytes.Compare(inodes[i].key(), key) != -1
|
||||
})
|
||||
e.index = index
|
||||
}
|
||||
|
||||
// keyValue returns the key and value of the current leaf element.
|
||||
func (c *Cursor) keyValue() ([]byte, []byte, uint32) {
|
||||
ref := &c.stack[len(c.stack)-1]
|
||||
if ref.count() == 0 || ref.index >= ref.count() {
|
||||
return nil, nil, 0
|
||||
}
|
||||
|
||||
// Retrieve value from node.
|
||||
if ref.node != nil {
|
||||
inode := &ref.node.inodes[ref.index]
|
||||
return inode.key, inode.value, inode.flags
|
||||
}
|
||||
|
||||
// Or retrieve value from page.
|
||||
elem := ref.page.leafPageElement(uint16(ref.index))
|
||||
return elem.key(), elem.value(), elem.flags
|
||||
}
|
||||
|
||||
// node returns the node that the cursor is currently positioned on.
|
||||
func (c *Cursor) node() *node {
|
||||
_assert(len(c.stack) > 0, "accessing a node with a zero-length cursor stack")
|
||||
|
||||
// If the top of the stack is a leaf node then just return it.
|
||||
if ref := &c.stack[len(c.stack)-1]; ref.node != nil && ref.isLeaf() {
|
||||
return ref.node
|
||||
}
|
||||
|
||||
// Start from root and traverse down the hierarchy.
|
||||
var n = c.stack[0].node
|
||||
if n == nil {
|
||||
n = c.bucket.node(c.stack[0].page.id, nil)
|
||||
}
|
||||
for _, ref := range c.stack[:len(c.stack)-1] {
|
||||
_assert(!n.isLeaf, "expected branch node")
|
||||
n = n.childAt(int(ref.index))
|
||||
}
|
||||
_assert(n.isLeaf, "expected leaf node")
|
||||
return n
|
||||
}
|
||||
|
||||
// elemRef represents a reference to an element on a given page/node.
|
||||
type elemRef struct {
|
||||
page *page
|
||||
node *node
|
||||
index int
|
||||
}
|
||||
|
||||
// isLeaf returns whether the ref is pointing at a leaf page/node.
|
||||
func (r *elemRef) isLeaf() bool {
|
||||
if r.node != nil {
|
||||
return r.node.isLeaf
|
||||
}
|
||||
return (r.page.flags & leafPageFlag) != 0
|
||||
}
|
||||
|
||||
// count returns the number of inodes or page elements.
|
||||
func (r *elemRef) count() int {
|
||||
if r.node != nil {
|
||||
return len(r.node.inodes)
|
||||
}
|
||||
return int(r.page.count)
|
||||
}
|
|
@ -0,0 +1,817 @@
|
|||
package bolt_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
// Ensure that a cursor can return a reference to the bucket that created it.
|
||||
func TestCursor_Bucket(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cb := b.Cursor().Bucket(); !reflect.DeepEqual(cb, b) {
|
||||
t.Fatal("cursor bucket mismatch")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can seek to the appropriate keys.
|
||||
func TestCursor_Seek(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("0001")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("bar"), []byte("0002")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("baz"), []byte("0003")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := b.CreateBucket([]byte("bkt")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
|
||||
// Exact match should go to the key.
|
||||
if k, v := c.Seek([]byte("bar")); !bytes.Equal(k, []byte("bar")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte("0002")) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
// Inexact match should go to the next key.
|
||||
if k, v := c.Seek([]byte("bas")); !bytes.Equal(k, []byte("baz")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte("0003")) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
// Low key should go to the first key.
|
||||
if k, v := c.Seek([]byte("")); !bytes.Equal(k, []byte("bar")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte("0002")) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
// High key should return no key.
|
||||
if k, v := c.Seek([]byte("zzz")); k != nil {
|
||||
t.Fatalf("expected nil key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("expected nil value: %v", v)
|
||||
}
|
||||
|
||||
// Buckets should return their key but no value.
|
||||
if k, v := c.Seek([]byte("bkt")); !bytes.Equal(k, []byte("bkt")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("expected nil value: %v", v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursor_Delete(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
const count = 1000
|
||||
|
||||
// Insert every other key between 0 and $count.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i < count; i += 1 {
|
||||
k := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(k, uint64(i))
|
||||
if err := b.Put(k, make([]byte, 100)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("sub")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
bound := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(bound, uint64(count/2))
|
||||
for key, _ := c.First(); bytes.Compare(key, bound) < 0; key, _ = c.Next() {
|
||||
if err := c.Delete(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Seek([]byte("sub"))
|
||||
if err := c.Delete(); err != bolt.ErrIncompatibleValue {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
stats := tx.Bucket([]byte("widgets")).Stats()
|
||||
if stats.KeyN != count/2+1 {
|
||||
t.Fatalf("unexpected KeyN: %d", stats.KeyN)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can seek to the appropriate keys when there are a
|
||||
// large number of keys. This test also checks that seek will always move
|
||||
// forward to the next key.
|
||||
//
|
||||
// Related: https://github.com/boltdb/bolt/pull/187
|
||||
func TestCursor_Seek_Large(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
var count = 10000
|
||||
|
||||
// Insert every other key between 0 and $count.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < count; i += 100 {
|
||||
for j := i; j < i+100; j += 2 {
|
||||
k := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(k, uint64(j))
|
||||
if err := b.Put(k, make([]byte, 100)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
for i := 0; i < count; i++ {
|
||||
seek := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(seek, uint64(i))
|
||||
|
||||
k, _ := c.Seek(seek)
|
||||
|
||||
// The last seek is beyond the end of the the range so
|
||||
// it should return nil.
|
||||
if i == count-1 {
|
||||
if k != nil {
|
||||
t.Fatal("expected nil key")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise we should seek to the exact key or the next key.
|
||||
num := binary.BigEndian.Uint64(k)
|
||||
if i%2 == 0 {
|
||||
if num != uint64(i) {
|
||||
t.Fatalf("unexpected num: %d", num)
|
||||
}
|
||||
} else {
|
||||
if num != uint64(i+1) {
|
||||
t.Fatalf("unexpected num: %d", num)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a cursor can iterate over an empty bucket without error.
|
||||
func TestCursor_EmptyBucket(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucket([]byte("widgets"))
|
||||
return err
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
k, v := c.First()
|
||||
if k != nil {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can reverse iterate over an empty bucket without error.
|
||||
func TestCursor_EmptyBucketReverse(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucket([]byte("widgets"))
|
||||
return err
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
k, v := c.Last()
|
||||
if k != nil {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can iterate over a single root with a couple elements.
|
||||
func TestCursor_Iterate_Leaf(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("baz"), []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte{0}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("bar"), []byte{1}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx, err := db.Begin(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
|
||||
k, v := c.First()
|
||||
if !bytes.Equal(k, []byte("bar")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte{1}) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
k, v = c.Next()
|
||||
if !bytes.Equal(k, []byte("baz")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte{}) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
k, v = c.Next()
|
||||
if !bytes.Equal(k, []byte("foo")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte{0}) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
k, v = c.Next()
|
||||
if k != nil {
|
||||
t.Fatalf("expected nil key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("expected nil value: %v", v)
|
||||
}
|
||||
|
||||
k, v = c.Next()
|
||||
if k != nil {
|
||||
t.Fatalf("expected nil key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("expected nil value: %v", v)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can iterate in reverse over a single root with a couple elements.
|
||||
func TestCursor_LeafRootReverse(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("baz"), []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte{0}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("bar"), []byte{1}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx, err := db.Begin(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
|
||||
if k, v := c.Last(); !bytes.Equal(k, []byte("foo")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte{0}) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
if k, v := c.Prev(); !bytes.Equal(k, []byte("baz")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte{}) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
if k, v := c.Prev(); !bytes.Equal(k, []byte("bar")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, []byte{1}) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
if k, v := c.Prev(); k != nil {
|
||||
t.Fatalf("expected nil key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("expected nil value: %v", v)
|
||||
}
|
||||
|
||||
if k, v := c.Prev(); k != nil {
|
||||
t.Fatalf("expected nil key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("expected nil value: %v", v)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can restart from the beginning.
|
||||
func TestCursor_Restart(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("bar"), []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx, err := db.Begin(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
|
||||
if k, _ := c.First(); !bytes.Equal(k, []byte("bar")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
}
|
||||
if k, _ := c.Next(); !bytes.Equal(k, []byte("foo")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
}
|
||||
|
||||
if k, _ := c.First(); !bytes.Equal(k, []byte("bar")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
}
|
||||
if k, _ := c.Next(); !bytes.Equal(k, []byte("foo")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a cursor can skip over empty pages that have been deleted.
|
||||
func TestCursor_First_EmptyPages(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
// Create 1000 keys in the "widgets" bucket.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
if err := b.Put(u64tob(uint64(i)), []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete half the keys and then try to iterate.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("widgets"))
|
||||
for i := 0; i < 600; i++ {
|
||||
if err := b.Delete(u64tob(uint64(i))); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
c := b.Cursor()
|
||||
var n int
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
n++
|
||||
}
|
||||
if n != 400 {
|
||||
t.Fatalf("unexpected key count: %d", n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx can iterate over all elements in a bucket.
|
||||
func TestCursor_QuickCheck(t *testing.T) {
|
||||
f := func(items testdata) bool {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
// Bulk insert all values.
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, item := range items {
|
||||
if err := b.Put(item.Key, item.Value); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Sort test data.
|
||||
sort.Sort(items)
|
||||
|
||||
// Iterate over all items and check consistency.
|
||||
var index = 0
|
||||
tx, err = db.Begin(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
for k, v := c.First(); k != nil && index < len(items); k, v = c.Next() {
|
||||
if !bytes.Equal(k, items[index].Key) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, items[index].Value) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
index++
|
||||
}
|
||||
if len(items) != index {
|
||||
t.Fatalf("unexpected item count: %v, expected %v", len(items), index)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
if err := quick.Check(f, qconfig()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a transaction can iterate over all elements in a bucket in reverse.
|
||||
func TestCursor_QuickCheck_Reverse(t *testing.T) {
|
||||
f := func(items testdata) bool {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
// Bulk insert all values.
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, item := range items {
|
||||
if err := b.Put(item.Key, item.Value); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Sort test data.
|
||||
sort.Sort(revtestdata(items))
|
||||
|
||||
// Iterate over all items and check consistency.
|
||||
var index = 0
|
||||
tx, err = db.Begin(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
for k, v := c.Last(); k != nil && index < len(items); k, v = c.Prev() {
|
||||
if !bytes.Equal(k, items[index].Key) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if !bytes.Equal(v, items[index].Value) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
index++
|
||||
}
|
||||
if len(items) != index {
|
||||
t.Fatalf("unexpected item count: %v, expected %v", len(items), index)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
if err := quick.Check(f, qconfig()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can iterate over subbuckets.
|
||||
func TestCursor_QuickCheck_BucketsOnly(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("foo")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("baz")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
var names []string
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
names = append(names, string(k))
|
||||
if v != nil {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(names, []string{"bar", "baz", "foo"}) {
|
||||
t.Fatalf("unexpected names: %+v", names)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx cursor can reverse iterate over subbuckets.
|
||||
func TestCursor_QuickCheck_BucketsOnly_Reverse(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("foo")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := b.CreateBucket([]byte("baz")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
var names []string
|
||||
c := tx.Bucket([]byte("widgets")).Cursor()
|
||||
for k, v := c.Last(); k != nil; k, v = c.Prev() {
|
||||
names = append(names, string(k))
|
||||
if v != nil {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(names, []string{"foo", "baz", "bar"}) {
|
||||
t.Fatalf("unexpected names: %+v", names)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleCursor() {
|
||||
// Open the database.
|
||||
db, err := bolt.Open(tempfile(), 0666, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove(db.Path())
|
||||
|
||||
// Start a read-write transaction.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
// Create a new bucket.
|
||||
b, err := tx.CreateBucket([]byte("animals"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert data into a bucket.
|
||||
if err := b.Put([]byte("dog"), []byte("fun")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("cat"), []byte("lame")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("liger"), []byte("awesome")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a cursor for iteration.
|
||||
c := b.Cursor()
|
||||
|
||||
// Iterate over items in sorted key order. This starts from the
|
||||
// first key/value pair and updates the k/v variables to the
|
||||
// next key/value on each iteration.
|
||||
//
|
||||
// The loop finishes at the end of the cursor when a nil key is returned.
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
fmt.Printf("A %s is %s.\n", k, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// A cat is lame.
|
||||
// A dog is fun.
|
||||
// A liger is awesome.
|
||||
}
|
||||
|
||||
func ExampleCursor_reverse() {
|
||||
// Open the database.
|
||||
db, err := bolt.Open(tempfile(), 0666, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove(db.Path())
|
||||
|
||||
// Start a read-write transaction.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
// Create a new bucket.
|
||||
b, err := tx.CreateBucket([]byte("animals"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert data into a bucket.
|
||||
if err := b.Put([]byte("dog"), []byte("fun")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("cat"), []byte("lame")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("liger"), []byte("awesome")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a cursor for iteration.
|
||||
c := b.Cursor()
|
||||
|
||||
// Iterate over items in reverse sorted key order. This starts
|
||||
// from the last key/value pair and updates the k/v variables to
|
||||
// the previous key/value on each iteration.
|
||||
//
|
||||
// The loop finishes at the beginning of the cursor when a nil key
|
||||
// is returned.
|
||||
for k, v := c.Last(); k != nil; k, v = c.Prev() {
|
||||
fmt.Printf("A %s is %s.\n", k, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Close the database to release the file lock.
|
||||
if err := db.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// A liger is awesome.
|
||||
// A dog is fun.
|
||||
// A cat is lame.
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
Package bolt implements a low-level key/value store in pure Go. It supports
|
||||
fully serializable transactions, ACID semantics, and lock-free MVCC with
|
||||
multiple readers and a single writer. Bolt can be used for projects that
|
||||
want a simple data store without the need to add large dependencies such as
|
||||
Postgres or MySQL.
|
||||
|
||||
Bolt is a single-level, zero-copy, B+tree data store. This means that Bolt is
|
||||
optimized for fast read access and does not require recovery in the event of a
|
||||
system crash. Transactions which have not finished committing will simply be
|
||||
rolled back in the event of a crash.
|
||||
|
||||
The design of Bolt is based on Howard Chu's LMDB database project.
|
||||
|
||||
Bolt currently works on Windows, Mac OS X, and Linux.
|
||||
|
||||
|
||||
Basics
|
||||
|
||||
There are only a few types in Bolt: DB, Bucket, Tx, and Cursor. The DB is
|
||||
a collection of buckets and is represented by a single file on disk. A bucket is
|
||||
a collection of unique keys that are associated with values.
|
||||
|
||||
Transactions provide either read-only or read-write access to the database.
|
||||
Read-only transactions can retrieve key/value pairs and can use Cursors to
|
||||
iterate over the dataset sequentially. Read-write transactions can create and
|
||||
delete buckets and can insert and remove keys. Only one read-write transaction
|
||||
is allowed at a time.
|
||||
|
||||
|
||||
Caveats
|
||||
|
||||
The database uses a read-only, memory-mapped data file to ensure that
|
||||
applications cannot corrupt the database, however, this means that keys and
|
||||
values returned from Bolt cannot be changed. Writing to a read-only byte slice
|
||||
will cause Go to panic.
|
||||
|
||||
Keys and values retrieved from the database are only valid for the life of
|
||||
the transaction. When used outside the transaction, these byte slices can
|
||||
point to different data or can point to invalid memory which will cause a panic.
|
||||
|
||||
|
||||
*/
|
||||
package bolt
|
|
@ -0,0 +1,71 @@
|
|||
package bolt
|
||||
|
||||
import "errors"
|
||||
|
||||
// These errors can be returned when opening or calling methods on a DB.
|
||||
var (
|
||||
// ErrDatabaseNotOpen is returned when a DB instance is accessed before it
|
||||
// is opened or after it is closed.
|
||||
ErrDatabaseNotOpen = errors.New("database not open")
|
||||
|
||||
// ErrDatabaseOpen is returned when opening a database that is
|
||||
// already open.
|
||||
ErrDatabaseOpen = errors.New("database already open")
|
||||
|
||||
// ErrInvalid is returned when both meta pages on a database are invalid.
|
||||
// This typically occurs when a file is not a bolt database.
|
||||
ErrInvalid = errors.New("invalid database")
|
||||
|
||||
// ErrVersionMismatch is returned when the data file was created with a
|
||||
// different version of Bolt.
|
||||
ErrVersionMismatch = errors.New("version mismatch")
|
||||
|
||||
// ErrChecksum is returned when either meta page checksum does not match.
|
||||
ErrChecksum = errors.New("checksum error")
|
||||
|
||||
// ErrTimeout is returned when a database cannot obtain an exclusive lock
|
||||
// on the data file after the timeout passed to Open().
|
||||
ErrTimeout = errors.New("timeout")
|
||||
)
|
||||
|
||||
// These errors can occur when beginning or committing a Tx.
|
||||
var (
|
||||
// ErrTxNotWritable is returned when performing a write operation on a
|
||||
// read-only transaction.
|
||||
ErrTxNotWritable = errors.New("tx not writable")
|
||||
|
||||
// ErrTxClosed is returned when committing or rolling back a transaction
|
||||
// that has already been committed or rolled back.
|
||||
ErrTxClosed = errors.New("tx closed")
|
||||
|
||||
// ErrDatabaseReadOnly is returned when a mutating transaction is started on a
|
||||
// read-only database.
|
||||
ErrDatabaseReadOnly = errors.New("database is in read-only mode")
|
||||
)
|
||||
|
||||
// These errors can occur when putting or deleting a value or a bucket.
|
||||
var (
|
||||
// ErrBucketNotFound is returned when trying to access a bucket that has
|
||||
// not been created yet.
|
||||
ErrBucketNotFound = errors.New("bucket not found")
|
||||
|
||||
// ErrBucketExists is returned when creating a bucket that already exists.
|
||||
ErrBucketExists = errors.New("bucket already exists")
|
||||
|
||||
// ErrBucketNameRequired is returned when creating a bucket with a blank name.
|
||||
ErrBucketNameRequired = errors.New("bucket name required")
|
||||
|
||||
// ErrKeyRequired is returned when inserting a zero-length key.
|
||||
ErrKeyRequired = errors.New("key required")
|
||||
|
||||
// ErrKeyTooLarge is returned when inserting a key that is larger than MaxKeySize.
|
||||
ErrKeyTooLarge = errors.New("key too large")
|
||||
|
||||
// ErrValueTooLarge is returned when inserting a value that is larger than MaxValueSize.
|
||||
ErrValueTooLarge = errors.New("value too large")
|
||||
|
||||
// ErrIncompatibleValue is returned when trying create or delete a bucket
|
||||
// on an existing non-bucket key or when trying to create or delete a
|
||||
// non-bucket key on an existing bucket key.
|
||||
ErrIncompatibleValue = errors.New("incompatible value")
|
||||
)
|
|
@ -0,0 +1,242 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// freelist represents a list of all pages that are available for allocation.
|
||||
// It also tracks pages that have been freed but are still in use by open transactions.
|
||||
type freelist struct {
|
||||
ids []pgid // all free and available free page ids.
|
||||
pending map[txid][]pgid // mapping of soon-to-be free page ids by tx.
|
||||
cache map[pgid]bool // fast lookup of all free and pending page ids.
|
||||
}
|
||||
|
||||
// newFreelist returns an empty, initialized freelist.
|
||||
func newFreelist() *freelist {
|
||||
return &freelist{
|
||||
pending: make(map[txid][]pgid),
|
||||
cache: make(map[pgid]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// size returns the size of the page after serialization.
|
||||
func (f *freelist) size() int {
|
||||
return pageHeaderSize + (int(unsafe.Sizeof(pgid(0))) * f.count())
|
||||
}
|
||||
|
||||
// count returns count of pages on the freelist
|
||||
func (f *freelist) count() int {
|
||||
return f.free_count() + f.pending_count()
|
||||
}
|
||||
|
||||
// free_count returns count of free pages
|
||||
func (f *freelist) free_count() int {
|
||||
return len(f.ids)
|
||||
}
|
||||
|
||||
// pending_count returns count of pending pages
|
||||
func (f *freelist) pending_count() int {
|
||||
var count int
|
||||
for _, list := range f.pending {
|
||||
count += len(list)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// all returns a list of all free ids and all pending ids in one sorted list.
|
||||
func (f *freelist) all() []pgid {
|
||||
m := make(pgids, 0)
|
||||
|
||||
for _, list := range f.pending {
|
||||
m = append(m, list...)
|
||||
}
|
||||
|
||||
sort.Sort(m)
|
||||
return pgids(f.ids).merge(m)
|
||||
}
|
||||
|
||||
// allocate returns the starting page id of a contiguous list of pages of a given size.
|
||||
// If a contiguous block cannot be found then 0 is returned.
|
||||
func (f *freelist) allocate(n int) pgid {
|
||||
if len(f.ids) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var initial, previd pgid
|
||||
for i, id := range f.ids {
|
||||
if id <= 1 {
|
||||
panic(fmt.Sprintf("invalid page allocation: %d", id))
|
||||
}
|
||||
|
||||
// Reset initial page if this is not contiguous.
|
||||
if previd == 0 || id-previd != 1 {
|
||||
initial = id
|
||||
}
|
||||
|
||||
// If we found a contiguous block then remove it and return it.
|
||||
if (id-initial)+1 == pgid(n) {
|
||||
// If we're allocating off the beginning then take the fast path
|
||||
// and just adjust the existing slice. This will use extra memory
|
||||
// temporarily but the append() in free() will realloc the slice
|
||||
// as is necessary.
|
||||
if (i + 1) == n {
|
||||
f.ids = f.ids[i+1:]
|
||||
} else {
|
||||
copy(f.ids[i-n+1:], f.ids[i+1:])
|
||||
f.ids = f.ids[:len(f.ids)-n]
|
||||
}
|
||||
|
||||
// Remove from the free cache.
|
||||
for i := pgid(0); i < pgid(n); i++ {
|
||||
delete(f.cache, initial+i)
|
||||
}
|
||||
|
||||
return initial
|
||||
}
|
||||
|
||||
previd = id
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// free releases a page and its overflow for a given transaction id.
|
||||
// If the page is already free then a panic will occur.
|
||||
func (f *freelist) free(txid txid, p *page) {
|
||||
if p.id <= 1 {
|
||||
panic(fmt.Sprintf("cannot free page 0 or 1: %d", p.id))
|
||||
}
|
||||
|
||||
// Free page and all its overflow pages.
|
||||
var ids = f.pending[txid]
|
||||
for id := p.id; id <= p.id+pgid(p.overflow); id++ {
|
||||
// Verify that page is not already free.
|
||||
if f.cache[id] {
|
||||
panic(fmt.Sprintf("page %d already freed", id))
|
||||
}
|
||||
|
||||
// Add to the freelist and cache.
|
||||
ids = append(ids, id)
|
||||
f.cache[id] = true
|
||||
}
|
||||
f.pending[txid] = ids
|
||||
}
|
||||
|
||||
// release moves all page ids for a transaction id (or older) to the freelist.
|
||||
func (f *freelist) release(txid txid) {
|
||||
m := make(pgids, 0)
|
||||
for tid, ids := range f.pending {
|
||||
if tid <= txid {
|
||||
// Move transaction's pending pages to the available freelist.
|
||||
// Don't remove from the cache since the page is still free.
|
||||
m = append(m, ids...)
|
||||
delete(f.pending, tid)
|
||||
}
|
||||
}
|
||||
sort.Sort(m)
|
||||
f.ids = pgids(f.ids).merge(m)
|
||||
}
|
||||
|
||||
// rollback removes the pages from a given pending tx.
|
||||
func (f *freelist) rollback(txid txid) {
|
||||
// Remove page ids from cache.
|
||||
for _, id := range f.pending[txid] {
|
||||
delete(f.cache, id)
|
||||
}
|
||||
|
||||
// Remove pages from pending list.
|
||||
delete(f.pending, txid)
|
||||
}
|
||||
|
||||
// freed returns whether a given page is in the free list.
|
||||
func (f *freelist) freed(pgid pgid) bool {
|
||||
return f.cache[pgid]
|
||||
}
|
||||
|
||||
// read initializes the freelist from a freelist page.
|
||||
func (f *freelist) read(p *page) {
|
||||
// If the page.count is at the max uint16 value (64k) then it's considered
|
||||
// an overflow and the size of the freelist is stored as the first element.
|
||||
idx, count := 0, int(p.count)
|
||||
if count == 0xFFFF {
|
||||
idx = 1
|
||||
count = int(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[0])
|
||||
}
|
||||
|
||||
// Copy the list of page ids from the freelist.
|
||||
ids := ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[idx:count]
|
||||
f.ids = make([]pgid, len(ids))
|
||||
copy(f.ids, ids)
|
||||
|
||||
// Make sure they're sorted.
|
||||
sort.Sort(pgids(f.ids))
|
||||
|
||||
// Rebuild the page cache.
|
||||
f.reindex()
|
||||
}
|
||||
|
||||
// write writes the page ids onto a freelist page. All free and pending ids are
|
||||
// saved to disk since in the event of a program crash, all pending ids will
|
||||
// become free.
|
||||
func (f *freelist) write(p *page) error {
|
||||
// Combine the old free pgids and pgids waiting on an open transaction.
|
||||
ids := f.all()
|
||||
|
||||
// Update the header flag.
|
||||
p.flags |= freelistPageFlag
|
||||
|
||||
// The page.count can only hold up to 64k elements so if we overflow that
|
||||
// number then we handle it by putting the size in the first element.
|
||||
if len(ids) < 0xFFFF {
|
||||
p.count = uint16(len(ids))
|
||||
copy(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[:], ids)
|
||||
} else {
|
||||
p.count = 0xFFFF
|
||||
((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[0] = pgid(len(ids))
|
||||
copy(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[1:], ids)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reload reads the freelist from a page and filters out pending items.
|
||||
func (f *freelist) reload(p *page) {
|
||||
f.read(p)
|
||||
|
||||
// Build a cache of only pending pages.
|
||||
pcache := make(map[pgid]bool)
|
||||
for _, pendingIDs := range f.pending {
|
||||
for _, pendingID := range pendingIDs {
|
||||
pcache[pendingID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check each page in the freelist and build a new available freelist
|
||||
// with any pages not in the pending lists.
|
||||
var a []pgid
|
||||
for _, id := range f.ids {
|
||||
if !pcache[id] {
|
||||
a = append(a, id)
|
||||
}
|
||||
}
|
||||
f.ids = a
|
||||
|
||||
// Once the available list is rebuilt then rebuild the free cache so that
|
||||
// it includes the available and pending free pages.
|
||||
f.reindex()
|
||||
}
|
||||
|
||||
// reindex rebuilds the free cache based on available and pending free lists.
|
||||
func (f *freelist) reindex() {
|
||||
f.cache = make(map[pgid]bool)
|
||||
for _, id := range f.ids {
|
||||
f.cache[id] = true
|
||||
}
|
||||
for _, pendingIDs := range f.pending {
|
||||
for _, pendingID := range pendingIDs {
|
||||
f.cache[pendingID] = true
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Ensure that a page is added to a transaction's freelist.
|
||||
func TestFreelist_free(t *testing.T) {
|
||||
f := newFreelist()
|
||||
f.free(100, &page{id: 12})
|
||||
if !reflect.DeepEqual([]pgid{12}, f.pending[100]) {
|
||||
t.Fatalf("exp=%v; got=%v", []pgid{12}, f.pending[100])
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a page and its overflow is added to a transaction's freelist.
|
||||
func TestFreelist_free_overflow(t *testing.T) {
|
||||
f := newFreelist()
|
||||
f.free(100, &page{id: 12, overflow: 3})
|
||||
if exp := []pgid{12, 13, 14, 15}; !reflect.DeepEqual(exp, f.pending[100]) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f.pending[100])
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a transaction's free pages can be released.
|
||||
func TestFreelist_release(t *testing.T) {
|
||||
f := newFreelist()
|
||||
f.free(100, &page{id: 12, overflow: 1})
|
||||
f.free(100, &page{id: 9})
|
||||
f.free(102, &page{id: 39})
|
||||
f.release(100)
|
||||
f.release(101)
|
||||
if exp := []pgid{9, 12, 13}; !reflect.DeepEqual(exp, f.ids) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f.ids)
|
||||
}
|
||||
|
||||
f.release(102)
|
||||
if exp := []pgid{9, 12, 13, 39}; !reflect.DeepEqual(exp, f.ids) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f.ids)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a freelist can find contiguous blocks of pages.
|
||||
func TestFreelist_allocate(t *testing.T) {
|
||||
f := &freelist{ids: []pgid{3, 4, 5, 6, 7, 9, 12, 13, 18}}
|
||||
if id := int(f.allocate(3)); id != 3 {
|
||||
t.Fatalf("exp=3; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(1)); id != 6 {
|
||||
t.Fatalf("exp=6; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(3)); id != 0 {
|
||||
t.Fatalf("exp=0; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(2)); id != 12 {
|
||||
t.Fatalf("exp=12; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(1)); id != 7 {
|
||||
t.Fatalf("exp=7; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(0)); id != 0 {
|
||||
t.Fatalf("exp=0; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(0)); id != 0 {
|
||||
t.Fatalf("exp=0; got=%v", id)
|
||||
}
|
||||
if exp := []pgid{9, 18}; !reflect.DeepEqual(exp, f.ids) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f.ids)
|
||||
}
|
||||
|
||||
if id := int(f.allocate(1)); id != 9 {
|
||||
t.Fatalf("exp=9; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(1)); id != 18 {
|
||||
t.Fatalf("exp=18; got=%v", id)
|
||||
}
|
||||
if id := int(f.allocate(1)); id != 0 {
|
||||
t.Fatalf("exp=0; got=%v", id)
|
||||
}
|
||||
if exp := []pgid{}; !reflect.DeepEqual(exp, f.ids) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f.ids)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a freelist can deserialize from a freelist page.
|
||||
func TestFreelist_read(t *testing.T) {
|
||||
// Create a page.
|
||||
var buf [4096]byte
|
||||
page := (*page)(unsafe.Pointer(&buf[0]))
|
||||
page.flags = freelistPageFlag
|
||||
page.count = 2
|
||||
|
||||
// Insert 2 page ids.
|
||||
ids := (*[3]pgid)(unsafe.Pointer(&page.ptr))
|
||||
ids[0] = 23
|
||||
ids[1] = 50
|
||||
|
||||
// Deserialize page into a freelist.
|
||||
f := newFreelist()
|
||||
f.read(page)
|
||||
|
||||
// Ensure that there are two page ids in the freelist.
|
||||
if exp := []pgid{23, 50}; !reflect.DeepEqual(exp, f.ids) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f.ids)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a freelist can serialize into a freelist page.
|
||||
func TestFreelist_write(t *testing.T) {
|
||||
// Create a freelist and write it to a page.
|
||||
var buf [4096]byte
|
||||
f := &freelist{ids: []pgid{12, 39}, pending: make(map[txid][]pgid)}
|
||||
f.pending[100] = []pgid{28, 11}
|
||||
f.pending[101] = []pgid{3}
|
||||
p := (*page)(unsafe.Pointer(&buf[0]))
|
||||
if err := f.write(p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the page back out.
|
||||
f2 := newFreelist()
|
||||
f2.read(p)
|
||||
|
||||
// Ensure that the freelist is correct.
|
||||
// All pages should be present and in reverse order.
|
||||
if exp := []pgid{3, 11, 12, 28, 39}; !reflect.DeepEqual(exp, f2.ids) {
|
||||
t.Fatalf("exp=%v; got=%v", exp, f2.ids)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_FreelistRelease10K(b *testing.B) { benchmark_FreelistRelease(b, 10000) }
|
||||
func Benchmark_FreelistRelease100K(b *testing.B) { benchmark_FreelistRelease(b, 100000) }
|
||||
func Benchmark_FreelistRelease1000K(b *testing.B) { benchmark_FreelistRelease(b, 1000000) }
|
||||
func Benchmark_FreelistRelease10000K(b *testing.B) { benchmark_FreelistRelease(b, 10000000) }
|
||||
|
||||
func benchmark_FreelistRelease(b *testing.B, size int) {
|
||||
ids := randomPgids(size)
|
||||
pending := randomPgids(len(ids) / 400)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
f := &freelist{ids: ids, pending: map[txid][]pgid{1: pending}}
|
||||
f.release(1)
|
||||
}
|
||||
}
|
||||
|
||||
func randomPgids(n int) []pgid {
|
||||
rand.Seed(42)
|
||||
pgids := make(pgids, n)
|
||||
for i := range pgids {
|
||||
pgids[i] = pgid(rand.Int63())
|
||||
}
|
||||
sort.Sort(pgids)
|
||||
return pgids
|
||||
}
|
|
@ -0,0 +1,599 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// node represents an in-memory, deserialized page.
|
||||
type node struct {
|
||||
bucket *Bucket
|
||||
isLeaf bool
|
||||
unbalanced bool
|
||||
spilled bool
|
||||
key []byte
|
||||
pgid pgid
|
||||
parent *node
|
||||
children nodes
|
||||
inodes inodes
|
||||
}
|
||||
|
||||
// root returns the top-level node this node is attached to.
|
||||
func (n *node) root() *node {
|
||||
if n.parent == nil {
|
||||
return n
|
||||
}
|
||||
return n.parent.root()
|
||||
}
|
||||
|
||||
// minKeys returns the minimum number of inodes this node should have.
|
||||
func (n *node) minKeys() int {
|
||||
if n.isLeaf {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
// size returns the size of the node after serialization.
|
||||
func (n *node) size() int {
|
||||
sz, elsz := pageHeaderSize, n.pageElementSize()
|
||||
for i := 0; i < len(n.inodes); i++ {
|
||||
item := &n.inodes[i]
|
||||
sz += elsz + len(item.key) + len(item.value)
|
||||
}
|
||||
return sz
|
||||
}
|
||||
|
||||
// sizeLessThan returns true if the node is less than a given size.
|
||||
// This is an optimization to avoid calculating a large node when we only need
|
||||
// to know if it fits inside a certain page size.
|
||||
func (n *node) sizeLessThan(v int) bool {
|
||||
sz, elsz := pageHeaderSize, n.pageElementSize()
|
||||
for i := 0; i < len(n.inodes); i++ {
|
||||
item := &n.inodes[i]
|
||||
sz += elsz + len(item.key) + len(item.value)
|
||||
if sz >= v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// pageElementSize returns the size of each page element based on the type of node.
|
||||
func (n *node) pageElementSize() int {
|
||||
if n.isLeaf {
|
||||
return leafPageElementSize
|
||||
}
|
||||
return branchPageElementSize
|
||||
}
|
||||
|
||||
// childAt returns the child node at a given index.
|
||||
func (n *node) childAt(index int) *node {
|
||||
if n.isLeaf {
|
||||
panic(fmt.Sprintf("invalid childAt(%d) on a leaf node", index))
|
||||
}
|
||||
return n.bucket.node(n.inodes[index].pgid, n)
|
||||
}
|
||||
|
||||
// childIndex returns the index of a given child node.
|
||||
func (n *node) childIndex(child *node) int {
|
||||
index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, child.key) != -1 })
|
||||
return index
|
||||
}
|
||||
|
||||
// numChildren returns the number of children.
|
||||
func (n *node) numChildren() int {
|
||||
return len(n.inodes)
|
||||
}
|
||||
|
||||
// nextSibling returns the next node with the same parent.
|
||||
func (n *node) nextSibling() *node {
|
||||
if n.parent == nil {
|
||||
return nil
|
||||
}
|
||||
index := n.parent.childIndex(n)
|
||||
if index >= n.parent.numChildren()-1 {
|
||||
return nil
|
||||
}
|
||||
return n.parent.childAt(index + 1)
|
||||
}
|
||||
|
||||
// prevSibling returns the previous node with the same parent.
|
||||
func (n *node) prevSibling() *node {
|
||||
if n.parent == nil {
|
||||
return nil
|
||||
}
|
||||
index := n.parent.childIndex(n)
|
||||
if index == 0 {
|
||||
return nil
|
||||
}
|
||||
return n.parent.childAt(index - 1)
|
||||
}
|
||||
|
||||
// put inserts a key/value.
|
||||
func (n *node) put(oldKey, newKey, value []byte, pgid pgid, flags uint32) {
|
||||
if pgid >= n.bucket.tx.meta.pgid {
|
||||
panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", pgid, n.bucket.tx.meta.pgid))
|
||||
} else if len(oldKey) <= 0 {
|
||||
panic("put: zero-length old key")
|
||||
} else if len(newKey) <= 0 {
|
||||
panic("put: zero-length new key")
|
||||
}
|
||||
|
||||
// Find insertion index.
|
||||
index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, oldKey) != -1 })
|
||||
|
||||
// Add capacity and shift nodes if we don't have an exact match and need to insert.
|
||||
exact := (len(n.inodes) > 0 && index < len(n.inodes) && bytes.Equal(n.inodes[index].key, oldKey))
|
||||
if !exact {
|
||||
n.inodes = append(n.inodes, inode{})
|
||||
copy(n.inodes[index+1:], n.inodes[index:])
|
||||
}
|
||||
|
||||
inode := &n.inodes[index]
|
||||
inode.flags = flags
|
||||
inode.key = newKey
|
||||
inode.value = value
|
||||
inode.pgid = pgid
|
||||
_assert(len(inode.key) > 0, "put: zero-length inode key")
|
||||
}
|
||||
|
||||
// del removes a key from the node.
|
||||
func (n *node) del(key []byte) {
|
||||
// Find index of key.
|
||||
index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, key) != -1 })
|
||||
|
||||
// Exit if the key isn't found.
|
||||
if index >= len(n.inodes) || !bytes.Equal(n.inodes[index].key, key) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete inode from the node.
|
||||
n.inodes = append(n.inodes[:index], n.inodes[index+1:]...)
|
||||
|
||||
// Mark the node as needing rebalancing.
|
||||
n.unbalanced = true
|
||||
}
|
||||
|
||||
// read initializes the node from a page.
|
||||
func (n *node) read(p *page) {
|
||||
n.pgid = p.id
|
||||
n.isLeaf = ((p.flags & leafPageFlag) != 0)
|
||||
n.inodes = make(inodes, int(p.count))
|
||||
|
||||
for i := 0; i < int(p.count); i++ {
|
||||
inode := &n.inodes[i]
|
||||
if n.isLeaf {
|
||||
elem := p.leafPageElement(uint16(i))
|
||||
inode.flags = elem.flags
|
||||
inode.key = elem.key()
|
||||
inode.value = elem.value()
|
||||
} else {
|
||||
elem := p.branchPageElement(uint16(i))
|
||||
inode.pgid = elem.pgid
|
||||
inode.key = elem.key()
|
||||
}
|
||||
_assert(len(inode.key) > 0, "read: zero-length inode key")
|
||||
}
|
||||
|
||||
// Save first key so we can find the node in the parent when we spill.
|
||||
if len(n.inodes) > 0 {
|
||||
n.key = n.inodes[0].key
|
||||
_assert(len(n.key) > 0, "read: zero-length node key")
|
||||
} else {
|
||||
n.key = nil
|
||||
}
|
||||
}
|
||||
|
||||
// write writes the items onto one or more pages.
|
||||
func (n *node) write(p *page) {
|
||||
// Initialize page.
|
||||
if n.isLeaf {
|
||||
p.flags |= leafPageFlag
|
||||
} else {
|
||||
p.flags |= branchPageFlag
|
||||
}
|
||||
|
||||
if len(n.inodes) >= 0xFFFF {
|
||||
panic(fmt.Sprintf("inode overflow: %d (pgid=%d)", len(n.inodes), p.id))
|
||||
}
|
||||
p.count = uint16(len(n.inodes))
|
||||
|
||||
// Loop over each item and write it to the page.
|
||||
b := (*[maxAllocSize]byte)(unsafe.Pointer(&p.ptr))[n.pageElementSize()*len(n.inodes):]
|
||||
for i, item := range n.inodes {
|
||||
_assert(len(item.key) > 0, "write: zero-length inode key")
|
||||
|
||||
// Write the page element.
|
||||
if n.isLeaf {
|
||||
elem := p.leafPageElement(uint16(i))
|
||||
elem.pos = uint32(uintptr(unsafe.Pointer(&b[0])) - uintptr(unsafe.Pointer(elem)))
|
||||
elem.flags = item.flags
|
||||
elem.ksize = uint32(len(item.key))
|
||||
elem.vsize = uint32(len(item.value))
|
||||
} else {
|
||||
elem := p.branchPageElement(uint16(i))
|
||||
elem.pos = uint32(uintptr(unsafe.Pointer(&b[0])) - uintptr(unsafe.Pointer(elem)))
|
||||
elem.ksize = uint32(len(item.key))
|
||||
elem.pgid = item.pgid
|
||||
_assert(elem.pgid != p.id, "write: circular dependency occurred")
|
||||
}
|
||||
|
||||
// If the length of key+value is larger than the max allocation size
|
||||
// then we need to reallocate the byte array pointer.
|
||||
//
|
||||
// See: https://github.com/boltdb/bolt/pull/335
|
||||
klen, vlen := len(item.key), len(item.value)
|
||||
if len(b) < klen+vlen {
|
||||
b = (*[maxAllocSize]byte)(unsafe.Pointer(&b[0]))[:]
|
||||
}
|
||||
|
||||
// Write data for the element to the end of the page.
|
||||
copy(b[0:], item.key)
|
||||
b = b[klen:]
|
||||
copy(b[0:], item.value)
|
||||
b = b[vlen:]
|
||||
}
|
||||
|
||||
// DEBUG ONLY: n.dump()
|
||||
}
|
||||
|
||||
// split breaks up a node into multiple smaller nodes, if appropriate.
|
||||
// This should only be called from the spill() function.
|
||||
func (n *node) split(pageSize int) []*node {
|
||||
var nodes []*node
|
||||
|
||||
node := n
|
||||
for {
|
||||
// Split node into two.
|
||||
a, b := node.splitTwo(pageSize)
|
||||
nodes = append(nodes, a)
|
||||
|
||||
// If we can't split then exit the loop.
|
||||
if b == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Set node to b so it gets split on the next iteration.
|
||||
node = b
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
// splitTwo breaks up a node into two smaller nodes, if appropriate.
|
||||
// This should only be called from the split() function.
|
||||
func (n *node) splitTwo(pageSize int) (*node, *node) {
|
||||
// Ignore the split if the page doesn't have at least enough nodes for
|
||||
// two pages or if the nodes can fit in a single page.
|
||||
if len(n.inodes) <= (minKeysPerPage*2) || n.sizeLessThan(pageSize) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Determine the threshold before starting a new node.
|
||||
var fillPercent = n.bucket.FillPercent
|
||||
if fillPercent < minFillPercent {
|
||||
fillPercent = minFillPercent
|
||||
} else if fillPercent > maxFillPercent {
|
||||
fillPercent = maxFillPercent
|
||||
}
|
||||
threshold := int(float64(pageSize) * fillPercent)
|
||||
|
||||
// Determine split position and sizes of the two pages.
|
||||
splitIndex, _ := n.splitIndex(threshold)
|
||||
|
||||
// Split node into two separate nodes.
|
||||
// If there's no parent then we'll need to create one.
|
||||
if n.parent == nil {
|
||||
n.parent = &node{bucket: n.bucket, children: []*node{n}}
|
||||
}
|
||||
|
||||
// Create a new node and add it to the parent.
|
||||
next := &node{bucket: n.bucket, isLeaf: n.isLeaf, parent: n.parent}
|
||||
n.parent.children = append(n.parent.children, next)
|
||||
|
||||
// Split inodes across two nodes.
|
||||
next.inodes = n.inodes[splitIndex:]
|
||||
n.inodes = n.inodes[:splitIndex]
|
||||
|
||||
// Update the statistics.
|
||||
n.bucket.tx.stats.Split++
|
||||
|
||||
return n, next
|
||||
}
|
||||
|
||||
// splitIndex finds the position where a page will fill a given threshold.
|
||||
// It returns the index as well as the size of the first page.
|
||||
// This is only be called from split().
|
||||
func (n *node) splitIndex(threshold int) (index, sz int) {
|
||||
sz = pageHeaderSize
|
||||
|
||||
// Loop until we only have the minimum number of keys required for the second page.
|
||||
for i := 0; i < len(n.inodes)-minKeysPerPage; i++ {
|
||||
index = i
|
||||
inode := n.inodes[i]
|
||||
elsize := n.pageElementSize() + len(inode.key) + len(inode.value)
|
||||
|
||||
// If we have at least the minimum number of keys and adding another
|
||||
// node would put us over the threshold then exit and return.
|
||||
if i >= minKeysPerPage && sz+elsize > threshold {
|
||||
break
|
||||
}
|
||||
|
||||
// Add the element size to the total size.
|
||||
sz += elsize
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// spill writes the nodes to dirty pages and splits nodes as it goes.
|
||||
// Returns an error if dirty pages cannot be allocated.
|
||||
func (n *node) spill() error {
|
||||
var tx = n.bucket.tx
|
||||
if n.spilled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Spill child nodes first. Child nodes can materialize sibling nodes in
|
||||
// the case of split-merge so we cannot use a range loop. We have to check
|
||||
// the children size on every loop iteration.
|
||||
sort.Sort(n.children)
|
||||
for i := 0; i < len(n.children); i++ {
|
||||
if err := n.children[i].spill(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// We no longer need the child list because it's only used for spill tracking.
|
||||
n.children = nil
|
||||
|
||||
// Split nodes into appropriate sizes. The first node will always be n.
|
||||
var nodes = n.split(tx.db.pageSize)
|
||||
for _, node := range nodes {
|
||||
// Add node's page to the freelist if it's not new.
|
||||
if node.pgid > 0 {
|
||||
tx.db.freelist.free(tx.meta.txid, tx.page(node.pgid))
|
||||
node.pgid = 0
|
||||
}
|
||||
|
||||
// Allocate contiguous space for the node.
|
||||
p, err := tx.allocate((node.size() / tx.db.pageSize) + 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the node.
|
||||
if p.id >= tx.meta.pgid {
|
||||
panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", p.id, tx.meta.pgid))
|
||||
}
|
||||
node.pgid = p.id
|
||||
node.write(p)
|
||||
node.spilled = true
|
||||
|
||||
// Insert into parent inodes.
|
||||
if node.parent != nil {
|
||||
var key = node.key
|
||||
if key == nil {
|
||||
key = node.inodes[0].key
|
||||
}
|
||||
|
||||
node.parent.put(key, node.inodes[0].key, nil, node.pgid, 0)
|
||||
node.key = node.inodes[0].key
|
||||
_assert(len(node.key) > 0, "spill: zero-length node key")
|
||||
}
|
||||
|
||||
// Update the statistics.
|
||||
tx.stats.Spill++
|
||||
}
|
||||
|
||||
// If the root node split and created a new root then we need to spill that
|
||||
// as well. We'll clear out the children to make sure it doesn't try to respill.
|
||||
if n.parent != nil && n.parent.pgid == 0 {
|
||||
n.children = nil
|
||||
return n.parent.spill()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rebalance attempts to combine the node with sibling nodes if the node fill
|
||||
// size is below a threshold or if there are not enough keys.
|
||||
func (n *node) rebalance() {
|
||||
if !n.unbalanced {
|
||||
return
|
||||
}
|
||||
n.unbalanced = false
|
||||
|
||||
// Update statistics.
|
||||
n.bucket.tx.stats.Rebalance++
|
||||
|
||||
// Ignore if node is above threshold (25%) and has enough keys.
|
||||
var threshold = n.bucket.tx.db.pageSize / 4
|
||||
if n.size() > threshold && len(n.inodes) > n.minKeys() {
|
||||
return
|
||||
}
|
||||
|
||||
// Root node has special handling.
|
||||
if n.parent == nil {
|
||||
// If root node is a branch and only has one node then collapse it.
|
||||
if !n.isLeaf && len(n.inodes) == 1 {
|
||||
// Move root's child up.
|
||||
child := n.bucket.node(n.inodes[0].pgid, n)
|
||||
n.isLeaf = child.isLeaf
|
||||
n.inodes = child.inodes[:]
|
||||
n.children = child.children
|
||||
|
||||
// Reparent all child nodes being moved.
|
||||
for _, inode := range n.inodes {
|
||||
if child, ok := n.bucket.nodes[inode.pgid]; ok {
|
||||
child.parent = n
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old child.
|
||||
child.parent = nil
|
||||
delete(n.bucket.nodes, child.pgid)
|
||||
child.free()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// If node has no keys then just remove it.
|
||||
if n.numChildren() == 0 {
|
||||
n.parent.del(n.key)
|
||||
n.parent.removeChild(n)
|
||||
delete(n.bucket.nodes, n.pgid)
|
||||
n.free()
|
||||
n.parent.rebalance()
|
||||
return
|
||||
}
|
||||
|
||||
_assert(n.parent.numChildren() > 1, "parent must have at least 2 children")
|
||||
|
||||
// Destination node is right sibling if idx == 0, otherwise left sibling.
|
||||
var target *node
|
||||
var useNextSibling = (n.parent.childIndex(n) == 0)
|
||||
if useNextSibling {
|
||||
target = n.nextSibling()
|
||||
} else {
|
||||
target = n.prevSibling()
|
||||
}
|
||||
|
||||
// If both this node and the target node are too small then merge them.
|
||||
if useNextSibling {
|
||||
// Reparent all child nodes being moved.
|
||||
for _, inode := range target.inodes {
|
||||
if child, ok := n.bucket.nodes[inode.pgid]; ok {
|
||||
child.parent.removeChild(child)
|
||||
child.parent = n
|
||||
child.parent.children = append(child.parent.children, child)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy over inodes from target and remove target.
|
||||
n.inodes = append(n.inodes, target.inodes...)
|
||||
n.parent.del(target.key)
|
||||
n.parent.removeChild(target)
|
||||
delete(n.bucket.nodes, target.pgid)
|
||||
target.free()
|
||||
} else {
|
||||
// Reparent all child nodes being moved.
|
||||
for _, inode := range n.inodes {
|
||||
if child, ok := n.bucket.nodes[inode.pgid]; ok {
|
||||
child.parent.removeChild(child)
|
||||
child.parent = target
|
||||
child.parent.children = append(child.parent.children, child)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy over inodes to target and remove node.
|
||||
target.inodes = append(target.inodes, n.inodes...)
|
||||
n.parent.del(n.key)
|
||||
n.parent.removeChild(n)
|
||||
delete(n.bucket.nodes, n.pgid)
|
||||
n.free()
|
||||
}
|
||||
|
||||
// Either this node or the target node was deleted from the parent so rebalance it.
|
||||
n.parent.rebalance()
|
||||
}
|
||||
|
||||
// removes a node from the list of in-memory children.
|
||||
// This does not affect the inodes.
|
||||
func (n *node) removeChild(target *node) {
|
||||
for i, child := range n.children {
|
||||
if child == target {
|
||||
n.children = append(n.children[:i], n.children[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dereference causes the node to copy all its inode key/value references to heap memory.
|
||||
// This is required when the mmap is reallocated so inodes are not pointing to stale data.
|
||||
func (n *node) dereference() {
|
||||
if n.key != nil {
|
||||
key := make([]byte, len(n.key))
|
||||
copy(key, n.key)
|
||||
n.key = key
|
||||
_assert(n.pgid == 0 || len(n.key) > 0, "dereference: zero-length node key on existing node")
|
||||
}
|
||||
|
||||
for i := range n.inodes {
|
||||
inode := &n.inodes[i]
|
||||
|
||||
key := make([]byte, len(inode.key))
|
||||
copy(key, inode.key)
|
||||
inode.key = key
|
||||
_assert(len(inode.key) > 0, "dereference: zero-length inode key")
|
||||
|
||||
value := make([]byte, len(inode.value))
|
||||
copy(value, inode.value)
|
||||
inode.value = value
|
||||
}
|
||||
|
||||
// Recursively dereference children.
|
||||
for _, child := range n.children {
|
||||
child.dereference()
|
||||
}
|
||||
|
||||
// Update statistics.
|
||||
n.bucket.tx.stats.NodeDeref++
|
||||
}
|
||||
|
||||
// free adds the node's underlying page to the freelist.
|
||||
func (n *node) free() {
|
||||
if n.pgid != 0 {
|
||||
n.bucket.tx.db.freelist.free(n.bucket.tx.meta.txid, n.bucket.tx.page(n.pgid))
|
||||
n.pgid = 0
|
||||
}
|
||||
}
|
||||
|
||||
// dump writes the contents of the node to STDERR for debugging purposes.
|
||||
/*
|
||||
func (n *node) dump() {
|
||||
// Write node header.
|
||||
var typ = "branch"
|
||||
if n.isLeaf {
|
||||
typ = "leaf"
|
||||
}
|
||||
warnf("[NODE %d {type=%s count=%d}]", n.pgid, typ, len(n.inodes))
|
||||
|
||||
// Write out abbreviated version of each item.
|
||||
for _, item := range n.inodes {
|
||||
if n.isLeaf {
|
||||
if item.flags&bucketLeafFlag != 0 {
|
||||
bucket := (*bucket)(unsafe.Pointer(&item.value[0]))
|
||||
warnf("+L %08x -> (bucket root=%d)", trunc(item.key, 4), bucket.root)
|
||||
} else {
|
||||
warnf("+L %08x -> %08x", trunc(item.key, 4), trunc(item.value, 4))
|
||||
}
|
||||
} else {
|
||||
warnf("+B %08x -> pgid=%d", trunc(item.key, 4), item.pgid)
|
||||
}
|
||||
}
|
||||
warn("")
|
||||
}
|
||||
*/
|
||||
|
||||
type nodes []*node
|
||||
|
||||
func (s nodes) Len() int { return len(s) }
|
||||
func (s nodes) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s nodes) Less(i, j int) bool { return bytes.Compare(s[i].inodes[0].key, s[j].inodes[0].key) == -1 }
|
||||
|
||||
// inode represents an internal node inside of a node.
|
||||
// It can be used to point to elements in a page or point
|
||||
// to an element which hasn't been added to a page yet.
|
||||
type inode struct {
|
||||
flags uint32
|
||||
pgid pgid
|
||||
key []byte
|
||||
value []byte
|
||||
}
|
||||
|
||||
type inodes []inode
|
|
@ -0,0 +1,156 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Ensure that a node can insert a key/value.
|
||||
func TestNode_put(t *testing.T) {
|
||||
n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{meta: &meta{pgid: 1}}}}
|
||||
n.put([]byte("baz"), []byte("baz"), []byte("2"), 0, 0)
|
||||
n.put([]byte("foo"), []byte("foo"), []byte("0"), 0, 0)
|
||||
n.put([]byte("bar"), []byte("bar"), []byte("1"), 0, 0)
|
||||
n.put([]byte("foo"), []byte("foo"), []byte("3"), 0, leafPageFlag)
|
||||
|
||||
if len(n.inodes) != 3 {
|
||||
t.Fatalf("exp=3; got=%d", len(n.inodes))
|
||||
}
|
||||
if k, v := n.inodes[0].key, n.inodes[0].value; string(k) != "bar" || string(v) != "1" {
|
||||
t.Fatalf("exp=<bar,1>; got=<%s,%s>", k, v)
|
||||
}
|
||||
if k, v := n.inodes[1].key, n.inodes[1].value; string(k) != "baz" || string(v) != "2" {
|
||||
t.Fatalf("exp=<baz,2>; got=<%s,%s>", k, v)
|
||||
}
|
||||
if k, v := n.inodes[2].key, n.inodes[2].value; string(k) != "foo" || string(v) != "3" {
|
||||
t.Fatalf("exp=<foo,3>; got=<%s,%s>", k, v)
|
||||
}
|
||||
if n.inodes[2].flags != uint32(leafPageFlag) {
|
||||
t.Fatalf("not a leaf: %d", n.inodes[2].flags)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a node can deserialize from a leaf page.
|
||||
func TestNode_read_LeafPage(t *testing.T) {
|
||||
// Create a page.
|
||||
var buf [4096]byte
|
||||
page := (*page)(unsafe.Pointer(&buf[0]))
|
||||
page.flags = leafPageFlag
|
||||
page.count = 2
|
||||
|
||||
// Insert 2 elements at the beginning. sizeof(leafPageElement) == 16
|
||||
nodes := (*[3]leafPageElement)(unsafe.Pointer(&page.ptr))
|
||||
nodes[0] = leafPageElement{flags: 0, pos: 32, ksize: 3, vsize: 4} // pos = sizeof(leafPageElement) * 2
|
||||
nodes[1] = leafPageElement{flags: 0, pos: 23, ksize: 10, vsize: 3} // pos = sizeof(leafPageElement) + 3 + 4
|
||||
|
||||
// Write data for the nodes at the end.
|
||||
data := (*[4096]byte)(unsafe.Pointer(&nodes[2]))
|
||||
copy(data[:], []byte("barfooz"))
|
||||
copy(data[7:], []byte("helloworldbye"))
|
||||
|
||||
// Deserialize page into a leaf.
|
||||
n := &node{}
|
||||
n.read(page)
|
||||
|
||||
// Check that there are two inodes with correct data.
|
||||
if !n.isLeaf {
|
||||
t.Fatal("expected leaf")
|
||||
}
|
||||
if len(n.inodes) != 2 {
|
||||
t.Fatalf("exp=2; got=%d", len(n.inodes))
|
||||
}
|
||||
if k, v := n.inodes[0].key, n.inodes[0].value; string(k) != "bar" || string(v) != "fooz" {
|
||||
t.Fatalf("exp=<bar,fooz>; got=<%s,%s>", k, v)
|
||||
}
|
||||
if k, v := n.inodes[1].key, n.inodes[1].value; string(k) != "helloworld" || string(v) != "bye" {
|
||||
t.Fatalf("exp=<helloworld,bye>; got=<%s,%s>", k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a node can serialize into a leaf page.
|
||||
func TestNode_write_LeafPage(t *testing.T) {
|
||||
// Create a node.
|
||||
n := &node{isLeaf: true, inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}}
|
||||
n.put([]byte("susy"), []byte("susy"), []byte("que"), 0, 0)
|
||||
n.put([]byte("ricki"), []byte("ricki"), []byte("lake"), 0, 0)
|
||||
n.put([]byte("john"), []byte("john"), []byte("johnson"), 0, 0)
|
||||
|
||||
// Write it to a page.
|
||||
var buf [4096]byte
|
||||
p := (*page)(unsafe.Pointer(&buf[0]))
|
||||
n.write(p)
|
||||
|
||||
// Read the page back in.
|
||||
n2 := &node{}
|
||||
n2.read(p)
|
||||
|
||||
// Check that the two pages are the same.
|
||||
if len(n2.inodes) != 3 {
|
||||
t.Fatalf("exp=3; got=%d", len(n2.inodes))
|
||||
}
|
||||
if k, v := n2.inodes[0].key, n2.inodes[0].value; string(k) != "john" || string(v) != "johnson" {
|
||||
t.Fatalf("exp=<john,johnson>; got=<%s,%s>", k, v)
|
||||
}
|
||||
if k, v := n2.inodes[1].key, n2.inodes[1].value; string(k) != "ricki" || string(v) != "lake" {
|
||||
t.Fatalf("exp=<ricki,lake>; got=<%s,%s>", k, v)
|
||||
}
|
||||
if k, v := n2.inodes[2].key, n2.inodes[2].value; string(k) != "susy" || string(v) != "que" {
|
||||
t.Fatalf("exp=<susy,que>; got=<%s,%s>", k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a node can split into appropriate subgroups.
|
||||
func TestNode_split(t *testing.T) {
|
||||
// Create a node.
|
||||
n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}}
|
||||
n.put([]byte("00000001"), []byte("00000001"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000002"), []byte("00000002"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000003"), []byte("00000003"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000004"), []byte("00000004"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000005"), []byte("00000005"), []byte("0123456701234567"), 0, 0)
|
||||
|
||||
// Split between 2 & 3.
|
||||
n.split(100)
|
||||
|
||||
var parent = n.parent
|
||||
if len(parent.children) != 2 {
|
||||
t.Fatalf("exp=2; got=%d", len(parent.children))
|
||||
}
|
||||
if len(parent.children[0].inodes) != 2 {
|
||||
t.Fatalf("exp=2; got=%d", len(parent.children[0].inodes))
|
||||
}
|
||||
if len(parent.children[1].inodes) != 3 {
|
||||
t.Fatalf("exp=3; got=%d", len(parent.children[1].inodes))
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a page with the minimum number of inodes just returns a single node.
|
||||
func TestNode_split_MinKeys(t *testing.T) {
|
||||
// Create a node.
|
||||
n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}}
|
||||
n.put([]byte("00000001"), []byte("00000001"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000002"), []byte("00000002"), []byte("0123456701234567"), 0, 0)
|
||||
|
||||
// Split.
|
||||
n.split(20)
|
||||
if n.parent != nil {
|
||||
t.Fatalf("expected nil parent")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a node that has keys that all fit on a page just returns one leaf.
|
||||
func TestNode_split_SinglePage(t *testing.T) {
|
||||
// Create a node.
|
||||
n := &node{inodes: make(inodes, 0), bucket: &Bucket{tx: &Tx{db: &DB{}, meta: &meta{pgid: 1}}}}
|
||||
n.put([]byte("00000001"), []byte("00000001"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000002"), []byte("00000002"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000003"), []byte("00000003"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000004"), []byte("00000004"), []byte("0123456701234567"), 0, 0)
|
||||
n.put([]byte("00000005"), []byte("00000005"), []byte("0123456701234567"), 0, 0)
|
||||
|
||||
// Split.
|
||||
n.split(4096)
|
||||
if n.parent != nil {
|
||||
t.Fatalf("expected nil parent")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const pageHeaderSize = int(unsafe.Offsetof(((*page)(nil)).ptr))
|
||||
|
||||
const minKeysPerPage = 2
|
||||
|
||||
const branchPageElementSize = int(unsafe.Sizeof(branchPageElement{}))
|
||||
const leafPageElementSize = int(unsafe.Sizeof(leafPageElement{}))
|
||||
|
||||
const (
|
||||
branchPageFlag = 0x01
|
||||
leafPageFlag = 0x02
|
||||
metaPageFlag = 0x04
|
||||
freelistPageFlag = 0x10
|
||||
)
|
||||
|
||||
const (
|
||||
bucketLeafFlag = 0x01
|
||||
)
|
||||
|
||||
type pgid uint64
|
||||
|
||||
type page struct {
|
||||
id pgid
|
||||
flags uint16
|
||||
count uint16
|
||||
overflow uint32
|
||||
ptr uintptr
|
||||
}
|
||||
|
||||
// typ returns a human readable page type string used for debugging.
|
||||
func (p *page) typ() string {
|
||||
if (p.flags & branchPageFlag) != 0 {
|
||||
return "branch"
|
||||
} else if (p.flags & leafPageFlag) != 0 {
|
||||
return "leaf"
|
||||
} else if (p.flags & metaPageFlag) != 0 {
|
||||
return "meta"
|
||||
} else if (p.flags & freelistPageFlag) != 0 {
|
||||
return "freelist"
|
||||
}
|
||||
return fmt.Sprintf("unknown<%02x>", p.flags)
|
||||
}
|
||||
|
||||
// meta returns a pointer to the metadata section of the page.
|
||||
func (p *page) meta() *meta {
|
||||
return (*meta)(unsafe.Pointer(&p.ptr))
|
||||
}
|
||||
|
||||
// leafPageElement retrieves the leaf node by index
|
||||
func (p *page) leafPageElement(index uint16) *leafPageElement {
|
||||
n := &((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[index]
|
||||
return n
|
||||
}
|
||||
|
||||
// leafPageElements retrieves a list of leaf nodes.
|
||||
func (p *page) leafPageElements() []leafPageElement {
|
||||
return ((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[:]
|
||||
}
|
||||
|
||||
// branchPageElement retrieves the branch node by index
|
||||
func (p *page) branchPageElement(index uint16) *branchPageElement {
|
||||
return &((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[index]
|
||||
}
|
||||
|
||||
// branchPageElements retrieves a list of branch nodes.
|
||||
func (p *page) branchPageElements() []branchPageElement {
|
||||
return ((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[:]
|
||||
}
|
||||
|
||||
// dump writes n bytes of the page to STDERR as hex output.
|
||||
func (p *page) hexdump(n int) {
|
||||
buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:n]
|
||||
fmt.Fprintf(os.Stderr, "%x\n", buf)
|
||||
}
|
||||
|
||||
type pages []*page
|
||||
|
||||
func (s pages) Len() int { return len(s) }
|
||||
func (s pages) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s pages) Less(i, j int) bool { return s[i].id < s[j].id }
|
||||
|
||||
// branchPageElement represents a node on a branch page.
|
||||
type branchPageElement struct {
|
||||
pos uint32
|
||||
ksize uint32
|
||||
pgid pgid
|
||||
}
|
||||
|
||||
// key returns a byte slice of the node key.
|
||||
func (n *branchPageElement) key() []byte {
|
||||
buf := (*[maxAllocSize]byte)(unsafe.Pointer(n))
|
||||
return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize]
|
||||
}
|
||||
|
||||
// leafPageElement represents a node on a leaf page.
|
||||
type leafPageElement struct {
|
||||
flags uint32
|
||||
pos uint32
|
||||
ksize uint32
|
||||
vsize uint32
|
||||
}
|
||||
|
||||
// key returns a byte slice of the node key.
|
||||
func (n *leafPageElement) key() []byte {
|
||||
buf := (*[maxAllocSize]byte)(unsafe.Pointer(n))
|
||||
return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize:n.ksize]
|
||||
}
|
||||
|
||||
// value returns a byte slice of the node value.
|
||||
func (n *leafPageElement) value() []byte {
|
||||
buf := (*[maxAllocSize]byte)(unsafe.Pointer(n))
|
||||
return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos+n.ksize]))[:n.vsize:n.vsize]
|
||||
}
|
||||
|
||||
// PageInfo represents human readable information about a page.
|
||||
type PageInfo struct {
|
||||
ID int
|
||||
Type string
|
||||
Count int
|
||||
OverflowCount int
|
||||
}
|
||||
|
||||
type pgids []pgid
|
||||
|
||||
func (s pgids) Len() int { return len(s) }
|
||||
func (s pgids) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s pgids) Less(i, j int) bool { return s[i] < s[j] }
|
||||
|
||||
// merge returns the sorted union of a and b.
|
||||
func (a pgids) merge(b pgids) pgids {
|
||||
// Return the opposite slice if one is nil.
|
||||
if len(a) == 0 {
|
||||
return b
|
||||
} else if len(b) == 0 {
|
||||
return a
|
||||
}
|
||||
|
||||
// Create a list to hold all elements from both lists.
|
||||
merged := make(pgids, 0, len(a)+len(b))
|
||||
|
||||
// Assign lead to the slice with a lower starting value, follow to the higher value.
|
||||
lead, follow := a, b
|
||||
if b[0] < a[0] {
|
||||
lead, follow = b, a
|
||||
}
|
||||
|
||||
// Continue while there are elements in the lead.
|
||||
for len(lead) > 0 {
|
||||
// Merge largest prefix of lead that is ahead of follow[0].
|
||||
n := sort.Search(len(lead), func(i int) bool { return lead[i] > follow[0] })
|
||||
merged = append(merged, lead[:n]...)
|
||||
if n >= len(lead) {
|
||||
break
|
||||
}
|
||||
|
||||
// Swap lead and follow.
|
||||
lead, follow = follow, lead[n:]
|
||||
}
|
||||
|
||||
// Append what's left in follow.
|
||||
merged = append(merged, follow...)
|
||||
|
||||
return merged
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
)
|
||||
|
||||
// Ensure that the page type can be returned in human readable format.
|
||||
func TestPage_typ(t *testing.T) {
|
||||
if typ := (&page{flags: branchPageFlag}).typ(); typ != "branch" {
|
||||
t.Fatalf("exp=branch; got=%v", typ)
|
||||
}
|
||||
if typ := (&page{flags: leafPageFlag}).typ(); typ != "leaf" {
|
||||
t.Fatalf("exp=leaf; got=%v", typ)
|
||||
}
|
||||
if typ := (&page{flags: metaPageFlag}).typ(); typ != "meta" {
|
||||
t.Fatalf("exp=meta; got=%v", typ)
|
||||
}
|
||||
if typ := (&page{flags: freelistPageFlag}).typ(); typ != "freelist" {
|
||||
t.Fatalf("exp=freelist; got=%v", typ)
|
||||
}
|
||||
if typ := (&page{flags: 20000}).typ(); typ != "unknown<4e20>" {
|
||||
t.Fatalf("exp=unknown<4e20>; got=%v", typ)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the hexdump debugging function doesn't blow up.
|
||||
func TestPage_dump(t *testing.T) {
|
||||
(&page{id: 256}).hexdump(16)
|
||||
}
|
||||
|
||||
func TestPgids_merge(t *testing.T) {
|
||||
a := pgids{4, 5, 6, 10, 11, 12, 13, 27}
|
||||
b := pgids{1, 3, 8, 9, 25, 30}
|
||||
c := a.merge(b)
|
||||
if !reflect.DeepEqual(c, pgids{1, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 25, 27, 30}) {
|
||||
t.Errorf("mismatch: %v", c)
|
||||
}
|
||||
|
||||
a = pgids{4, 5, 6, 10, 11, 12, 13, 27, 35, 36}
|
||||
b = pgids{8, 9, 25, 30}
|
||||
c = a.merge(b)
|
||||
if !reflect.DeepEqual(c, pgids{4, 5, 6, 8, 9, 10, 11, 12, 13, 25, 27, 30, 35, 36}) {
|
||||
t.Errorf("mismatch: %v", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgids_merge_quick(t *testing.T) {
|
||||
if err := quick.Check(func(a, b pgids) bool {
|
||||
// Sort incoming lists.
|
||||
sort.Sort(a)
|
||||
sort.Sort(b)
|
||||
|
||||
// Merge the two lists together.
|
||||
got := a.merge(b)
|
||||
|
||||
// The expected value should be the two lists combined and sorted.
|
||||
exp := append(a, b...)
|
||||
sort.Sort(exp)
|
||||
|
||||
if !reflect.DeepEqual(exp, got) {
|
||||
t.Errorf("\nexp=%+v\ngot=%+v\n", exp, got)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package bolt_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing/quick"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testing/quick defaults to 5 iterations and a random seed.
|
||||
// You can override these settings from the command line:
|
||||
//
|
||||
// -quick.count The number of iterations to perform.
|
||||
// -quick.seed The seed to use for randomizing.
|
||||
// -quick.maxitems The maximum number of items to insert into a DB.
|
||||
// -quick.maxksize The maximum size of a key.
|
||||
// -quick.maxvsize The maximum size of a value.
|
||||
//
|
||||
|
||||
var qcount, qseed, qmaxitems, qmaxksize, qmaxvsize int
|
||||
|
||||
func init() {
|
||||
flag.IntVar(&qcount, "quick.count", 5, "")
|
||||
flag.IntVar(&qseed, "quick.seed", int(time.Now().UnixNano())%100000, "")
|
||||
flag.IntVar(&qmaxitems, "quick.maxitems", 1000, "")
|
||||
flag.IntVar(&qmaxksize, "quick.maxksize", 1024, "")
|
||||
flag.IntVar(&qmaxvsize, "quick.maxvsize", 1024, "")
|
||||
flag.Parse()
|
||||
fmt.Fprintln(os.Stderr, "seed:", qseed)
|
||||
fmt.Fprintf(os.Stderr, "quick settings: count=%v, items=%v, ksize=%v, vsize=%v\n", qcount, qmaxitems, qmaxksize, qmaxvsize)
|
||||
}
|
||||
|
||||
func qconfig() *quick.Config {
|
||||
return &quick.Config{
|
||||
MaxCount: qcount,
|
||||
Rand: rand.New(rand.NewSource(int64(qseed))),
|
||||
}
|
||||
}
|
||||
|
||||
type testdata []testdataitem
|
||||
|
||||
func (t testdata) Len() int { return len(t) }
|
||||
func (t testdata) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
|
||||
func (t testdata) Less(i, j int) bool { return bytes.Compare(t[i].Key, t[j].Key) == -1 }
|
||||
|
||||
func (t testdata) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
n := rand.Intn(qmaxitems-1) + 1
|
||||
items := make(testdata, n)
|
||||
for i := 0; i < n; i++ {
|
||||
item := &items[i]
|
||||
item.Key = randByteSlice(rand, 1, qmaxksize)
|
||||
item.Value = randByteSlice(rand, 0, qmaxvsize)
|
||||
}
|
||||
return reflect.ValueOf(items)
|
||||
}
|
||||
|
||||
type revtestdata []testdataitem
|
||||
|
||||
func (t revtestdata) Len() int { return len(t) }
|
||||
func (t revtestdata) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
|
||||
func (t revtestdata) Less(i, j int) bool { return bytes.Compare(t[i].Key, t[j].Key) == 1 }
|
||||
|
||||
type testdataitem struct {
|
||||
Key []byte
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func randByteSlice(rand *rand.Rand, minSize, maxSize int) []byte {
|
||||
n := rand.Intn(maxSize-minSize) + minSize
|
||||
b := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
b[i] = byte(rand.Intn(255))
|
||||
}
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,329 @@
|
|||
package bolt_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
func TestSimulate_1op_1p(t *testing.T) { testSimulate(t, 1, 1) }
|
||||
func TestSimulate_10op_1p(t *testing.T) { testSimulate(t, 10, 1) }
|
||||
func TestSimulate_100op_1p(t *testing.T) { testSimulate(t, 100, 1) }
|
||||
func TestSimulate_1000op_1p(t *testing.T) { testSimulate(t, 1000, 1) }
|
||||
func TestSimulate_10000op_1p(t *testing.T) { testSimulate(t, 10000, 1) }
|
||||
|
||||
func TestSimulate_10op_10p(t *testing.T) { testSimulate(t, 10, 10) }
|
||||
func TestSimulate_100op_10p(t *testing.T) { testSimulate(t, 100, 10) }
|
||||
func TestSimulate_1000op_10p(t *testing.T) { testSimulate(t, 1000, 10) }
|
||||
func TestSimulate_10000op_10p(t *testing.T) { testSimulate(t, 10000, 10) }
|
||||
|
||||
func TestSimulate_100op_100p(t *testing.T) { testSimulate(t, 100, 100) }
|
||||
func TestSimulate_1000op_100p(t *testing.T) { testSimulate(t, 1000, 100) }
|
||||
func TestSimulate_10000op_100p(t *testing.T) { testSimulate(t, 10000, 100) }
|
||||
|
||||
func TestSimulate_10000op_1000p(t *testing.T) { testSimulate(t, 10000, 1000) }
|
||||
|
||||
// Randomly generate operations on a given database with multiple clients to ensure consistency and thread safety.
|
||||
func testSimulate(t *testing.T, threadCount, parallelism int) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
rand.Seed(int64(qseed))
|
||||
|
||||
// A list of operations that readers and writers can perform.
|
||||
var readerHandlers = []simulateHandler{simulateGetHandler}
|
||||
var writerHandlers = []simulateHandler{simulateGetHandler, simulatePutHandler}
|
||||
|
||||
var versions = make(map[int]*QuickDB)
|
||||
versions[1] = NewQuickDB()
|
||||
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
var mutex sync.Mutex
|
||||
|
||||
// Run n threads in parallel, each with their own operation.
|
||||
var wg sync.WaitGroup
|
||||
var threads = make(chan bool, parallelism)
|
||||
var i int
|
||||
for {
|
||||
threads <- true
|
||||
wg.Add(1)
|
||||
writable := ((rand.Int() % 100) < 20) // 20% writers
|
||||
|
||||
// Choose an operation to execute.
|
||||
var handler simulateHandler
|
||||
if writable {
|
||||
handler = writerHandlers[rand.Intn(len(writerHandlers))]
|
||||
} else {
|
||||
handler = readerHandlers[rand.Intn(len(readerHandlers))]
|
||||
}
|
||||
|
||||
// Execute a thread for the given operation.
|
||||
go func(writable bool, handler simulateHandler) {
|
||||
defer wg.Done()
|
||||
|
||||
// Start transaction.
|
||||
tx, err := db.Begin(writable)
|
||||
if err != nil {
|
||||
t.Fatal("tx begin: ", err)
|
||||
}
|
||||
|
||||
// Obtain current state of the dataset.
|
||||
mutex.Lock()
|
||||
var qdb = versions[tx.ID()]
|
||||
if writable {
|
||||
qdb = versions[tx.ID()-1].Copy()
|
||||
}
|
||||
mutex.Unlock()
|
||||
|
||||
// Make sure we commit/rollback the tx at the end and update the state.
|
||||
if writable {
|
||||
defer func() {
|
||||
mutex.Lock()
|
||||
versions[tx.ID()] = qdb
|
||||
mutex.Unlock()
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
}
|
||||
|
||||
// Ignore operation if we don't have data yet.
|
||||
if qdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Execute handler.
|
||||
handler(tx, qdb)
|
||||
|
||||
// Release a thread back to the scheduling loop.
|
||||
<-threads
|
||||
}(writable, handler)
|
||||
|
||||
i++
|
||||
if i > threadCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads are done.
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type simulateHandler func(tx *bolt.Tx, qdb *QuickDB)
|
||||
|
||||
// Retrieves a key from the database and verifies that it is what is expected.
|
||||
func simulateGetHandler(tx *bolt.Tx, qdb *QuickDB) {
|
||||
// Randomly retrieve an existing exist.
|
||||
keys := qdb.Rand()
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve root bucket.
|
||||
b := tx.Bucket(keys[0])
|
||||
if b == nil {
|
||||
panic(fmt.Sprintf("bucket[0] expected: %08x\n", trunc(keys[0], 4)))
|
||||
}
|
||||
|
||||
// Drill into nested buckets.
|
||||
for _, key := range keys[1 : len(keys)-1] {
|
||||
b = b.Bucket(key)
|
||||
if b == nil {
|
||||
panic(fmt.Sprintf("bucket[n] expected: %v -> %v\n", keys, key))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify key/value on the final bucket.
|
||||
expected := qdb.Get(keys)
|
||||
actual := b.Get(keys[len(keys)-1])
|
||||
if !bytes.Equal(actual, expected) {
|
||||
fmt.Println("=== EXPECTED ===")
|
||||
fmt.Println(expected)
|
||||
fmt.Println("=== ACTUAL ===")
|
||||
fmt.Println(actual)
|
||||
fmt.Println("=== END ===")
|
||||
panic("value mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Inserts a key into the database.
|
||||
func simulatePutHandler(tx *bolt.Tx, qdb *QuickDB) {
|
||||
var err error
|
||||
keys, value := randKeys(), randValue()
|
||||
|
||||
// Retrieve root bucket.
|
||||
b := tx.Bucket(keys[0])
|
||||
if b == nil {
|
||||
b, err = tx.CreateBucket(keys[0])
|
||||
if err != nil {
|
||||
panic("create bucket: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Create nested buckets, if necessary.
|
||||
for _, key := range keys[1 : len(keys)-1] {
|
||||
child := b.Bucket(key)
|
||||
if child != nil {
|
||||
b = child
|
||||
} else {
|
||||
b, err = b.CreateBucket(key)
|
||||
if err != nil {
|
||||
panic("create bucket: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Insert into database.
|
||||
if err := b.Put(keys[len(keys)-1], value); err != nil {
|
||||
panic("put: " + err.Error())
|
||||
}
|
||||
|
||||
// Insert into in-memory database.
|
||||
qdb.Put(keys, value)
|
||||
}
|
||||
|
||||
// QuickDB is an in-memory database that replicates the functionality of the
|
||||
// Bolt DB type except that it is entirely in-memory. It is meant for testing
|
||||
// that the Bolt database is consistent.
|
||||
type QuickDB struct {
|
||||
sync.RWMutex
|
||||
m map[string]interface{}
|
||||
}
|
||||
|
||||
// NewQuickDB returns an instance of QuickDB.
|
||||
func NewQuickDB() *QuickDB {
|
||||
return &QuickDB{m: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
// Get retrieves the value at a key path.
|
||||
func (db *QuickDB) Get(keys [][]byte) []byte {
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
|
||||
m := db.m
|
||||
for _, key := range keys[:len(keys)-1] {
|
||||
value := m[string(key)]
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
switch value := value.(type) {
|
||||
case map[string]interface{}:
|
||||
m = value
|
||||
case []byte:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Only return if it's a simple value.
|
||||
if value, ok := m[string(keys[len(keys)-1])].([]byte); ok {
|
||||
return value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put inserts a value into a key path.
|
||||
func (db *QuickDB) Put(keys [][]byte, value []byte) {
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
|
||||
// Build buckets all the way down the key path.
|
||||
m := db.m
|
||||
for _, key := range keys[:len(keys)-1] {
|
||||
if _, ok := m[string(key)].([]byte); ok {
|
||||
return // Keypath intersects with a simple value. Do nothing.
|
||||
}
|
||||
|
||||
if m[string(key)] == nil {
|
||||
m[string(key)] = make(map[string]interface{})
|
||||
}
|
||||
m = m[string(key)].(map[string]interface{})
|
||||
}
|
||||
|
||||
// Insert value into the last key.
|
||||
m[string(keys[len(keys)-1])] = value
|
||||
}
|
||||
|
||||
// Rand returns a random key path that points to a simple value.
|
||||
func (db *QuickDB) Rand() [][]byte {
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
if len(db.m) == 0 {
|
||||
return nil
|
||||
}
|
||||
var keys [][]byte
|
||||
db.rand(db.m, &keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func (db *QuickDB) rand(m map[string]interface{}, keys *[][]byte) {
|
||||
i, index := 0, rand.Intn(len(m))
|
||||
for k, v := range m {
|
||||
if i == index {
|
||||
*keys = append(*keys, []byte(k))
|
||||
if v, ok := v.(map[string]interface{}); ok {
|
||||
db.rand(v, keys)
|
||||
}
|
||||
return
|
||||
}
|
||||
i++
|
||||
}
|
||||
panic("quickdb rand: out-of-range")
|
||||
}
|
||||
|
||||
// Copy copies the entire database.
|
||||
func (db *QuickDB) Copy() *QuickDB {
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
return &QuickDB{m: db.copy(db.m)}
|
||||
}
|
||||
|
||||
func (db *QuickDB) copy(m map[string]interface{}) map[string]interface{} {
|
||||
clone := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
switch v := v.(type) {
|
||||
case map[string]interface{}:
|
||||
clone[k] = db.copy(v)
|
||||
default:
|
||||
clone[k] = v
|
||||
}
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func randKey() []byte {
|
||||
var min, max = 1, 1024
|
||||
n := rand.Intn(max-min) + min
|
||||
b := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
b[i] = byte(rand.Intn(255))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func randKeys() [][]byte {
|
||||
var keys [][]byte
|
||||
var count = rand.Intn(2) + 2
|
||||
for i := 0; i < count; i++ {
|
||||
keys = append(keys, randKey())
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func randValue() []byte {
|
||||
n := rand.Intn(8192)
|
||||
b := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
b[i] = byte(rand.Intn(255))
|
||||
}
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,682 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// txid represents the internal transaction identifier.
|
||||
type txid uint64
|
||||
|
||||
// Tx represents a read-only or read/write transaction on the database.
|
||||
// Read-only transactions can be used for retrieving values for keys and creating cursors.
|
||||
// Read/write transactions can create and remove buckets and create and remove keys.
|
||||
//
|
||||
// IMPORTANT: You must commit or rollback transactions when you are done with
|
||||
// them. Pages can not be reclaimed by the writer until no more transactions
|
||||
// are using them. A long running read transaction can cause the database to
|
||||
// quickly grow.
|
||||
type Tx struct {
|
||||
writable bool
|
||||
managed bool
|
||||
db *DB
|
||||
meta *meta
|
||||
root Bucket
|
||||
pages map[pgid]*page
|
||||
stats TxStats
|
||||
commitHandlers []func()
|
||||
|
||||
// WriteFlag specifies the flag for write-related methods like WriteTo().
|
||||
// Tx opens the database file with the specified flag to copy the data.
|
||||
//
|
||||
// By default, the flag is unset, which works well for mostly in-memory
|
||||
// workloads. For databases that are much larger than available RAM,
|
||||
// set the flag to syscall.O_DIRECT to avoid trashing the page cache.
|
||||
WriteFlag int
|
||||
}
|
||||
|
||||
// init initializes the transaction.
|
||||
func (tx *Tx) init(db *DB) {
|
||||
tx.db = db
|
||||
tx.pages = nil
|
||||
|
||||
// Copy the meta page since it can be changed by the writer.
|
||||
tx.meta = &meta{}
|
||||
db.meta().copy(tx.meta)
|
||||
|
||||
// Copy over the root bucket.
|
||||
tx.root = newBucket(tx)
|
||||
tx.root.bucket = &bucket{}
|
||||
*tx.root.bucket = tx.meta.root
|
||||
|
||||
// Increment the transaction id and add a page cache for writable transactions.
|
||||
if tx.writable {
|
||||
tx.pages = make(map[pgid]*page)
|
||||
tx.meta.txid += txid(1)
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the transaction id.
|
||||
func (tx *Tx) ID() int {
|
||||
return int(tx.meta.txid)
|
||||
}
|
||||
|
||||
// DB returns a reference to the database that created the transaction.
|
||||
func (tx *Tx) DB() *DB {
|
||||
return tx.db
|
||||
}
|
||||
|
||||
// Size returns current database size in bytes as seen by this transaction.
|
||||
func (tx *Tx) Size() int64 {
|
||||
return int64(tx.meta.pgid) * int64(tx.db.pageSize)
|
||||
}
|
||||
|
||||
// Writable returns whether the transaction can perform write operations.
|
||||
func (tx *Tx) Writable() bool {
|
||||
return tx.writable
|
||||
}
|
||||
|
||||
// Cursor creates a cursor associated with the root bucket.
|
||||
// All items in the cursor will return a nil value because all root bucket keys point to buckets.
|
||||
// The cursor is only valid as long as the transaction is open.
|
||||
// Do not use a cursor after the transaction is closed.
|
||||
func (tx *Tx) Cursor() *Cursor {
|
||||
return tx.root.Cursor()
|
||||
}
|
||||
|
||||
// Stats retrieves a copy of the current transaction statistics.
|
||||
func (tx *Tx) Stats() TxStats {
|
||||
return tx.stats
|
||||
}
|
||||
|
||||
// Bucket retrieves a bucket by name.
|
||||
// Returns nil if the bucket does not exist.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (tx *Tx) Bucket(name []byte) *Bucket {
|
||||
return tx.root.Bucket(name)
|
||||
}
|
||||
|
||||
// CreateBucket creates a new bucket.
|
||||
// Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (tx *Tx) CreateBucket(name []byte) (*Bucket, error) {
|
||||
return tx.root.CreateBucket(name)
|
||||
}
|
||||
|
||||
// CreateBucketIfNotExists creates a new bucket if it doesn't already exist.
|
||||
// Returns an error if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (tx *Tx) CreateBucketIfNotExists(name []byte) (*Bucket, error) {
|
||||
return tx.root.CreateBucketIfNotExists(name)
|
||||
}
|
||||
|
||||
// DeleteBucket deletes a bucket.
|
||||
// Returns an error if the bucket cannot be found or if the key represents a non-bucket value.
|
||||
func (tx *Tx) DeleteBucket(name []byte) error {
|
||||
return tx.root.DeleteBucket(name)
|
||||
}
|
||||
|
||||
// ForEach executes a function for each bucket in the root.
|
||||
// If the provided function returns an error then the iteration is stopped and
|
||||
// the error is returned to the caller.
|
||||
func (tx *Tx) ForEach(fn func(name []byte, b *Bucket) error) error {
|
||||
return tx.root.ForEach(func(k, v []byte) error {
|
||||
if err := fn(k, tx.root.Bucket(k)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// OnCommit adds a handler function to be executed after the transaction successfully commits.
|
||||
func (tx *Tx) OnCommit(fn func()) {
|
||||
tx.commitHandlers = append(tx.commitHandlers, fn)
|
||||
}
|
||||
|
||||
// Commit writes all changes to disk and updates the meta page.
|
||||
// Returns an error if a disk write error occurs, or if Commit is
|
||||
// called on a read-only transaction.
|
||||
func (tx *Tx) Commit() error {
|
||||
_assert(!tx.managed, "managed tx commit not allowed")
|
||||
if tx.db == nil {
|
||||
return ErrTxClosed
|
||||
} else if !tx.writable {
|
||||
return ErrTxNotWritable
|
||||
}
|
||||
|
||||
// TODO(benbjohnson): Use vectorized I/O to write out dirty pages.
|
||||
|
||||
// Rebalance nodes which have had deletions.
|
||||
var startTime = time.Now()
|
||||
tx.root.rebalance()
|
||||
if tx.stats.Rebalance > 0 {
|
||||
tx.stats.RebalanceTime += time.Since(startTime)
|
||||
}
|
||||
|
||||
// spill data onto dirty pages.
|
||||
startTime = time.Now()
|
||||
if err := tx.root.spill(); err != nil {
|
||||
tx.rollback()
|
||||
return err
|
||||
}
|
||||
tx.stats.SpillTime += time.Since(startTime)
|
||||
|
||||
// Free the old root bucket.
|
||||
tx.meta.root.root = tx.root.root
|
||||
|
||||
opgid := tx.meta.pgid
|
||||
|
||||
// Free the freelist and allocate new pages for it. This will overestimate
|
||||
// the size of the freelist but not underestimate the size (which would be bad).
|
||||
tx.db.freelist.free(tx.meta.txid, tx.db.page(tx.meta.freelist))
|
||||
p, err := tx.allocate((tx.db.freelist.size() / tx.db.pageSize) + 1)
|
||||
if err != nil {
|
||||
tx.rollback()
|
||||
return err
|
||||
}
|
||||
if err := tx.db.freelist.write(p); err != nil {
|
||||
tx.rollback()
|
||||
return err
|
||||
}
|
||||
tx.meta.freelist = p.id
|
||||
|
||||
// If the high water mark has moved up then attempt to grow the database.
|
||||
if tx.meta.pgid > opgid {
|
||||
if err := tx.db.grow(int(tx.meta.pgid+1) * tx.db.pageSize); err != nil {
|
||||
tx.rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write dirty pages to disk.
|
||||
startTime = time.Now()
|
||||
if err := tx.write(); err != nil {
|
||||
tx.rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// If strict mode is enabled then perform a consistency check.
|
||||
// Only the first consistency error is reported in the panic.
|
||||
if tx.db.StrictMode {
|
||||
ch := tx.Check()
|
||||
var errs []string
|
||||
for {
|
||||
err, ok := <-ch
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
panic("check fail: " + strings.Join(errs, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
// Write meta to disk.
|
||||
if err := tx.writeMeta(); err != nil {
|
||||
tx.rollback()
|
||||
return err
|
||||
}
|
||||
tx.stats.WriteTime += time.Since(startTime)
|
||||
|
||||
// Finalize the transaction.
|
||||
tx.close()
|
||||
|
||||
// Execute commit handlers now that the locks have been removed.
|
||||
for _, fn := range tx.commitHandlers {
|
||||
fn()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback closes the transaction and ignores all previous updates. Read-only
|
||||
// transactions must be rolled back and not committed.
|
||||
func (tx *Tx) Rollback() error {
|
||||
_assert(!tx.managed, "managed tx rollback not allowed")
|
||||
if tx.db == nil {
|
||||
return ErrTxClosed
|
||||
}
|
||||
tx.rollback()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Tx) rollback() {
|
||||
if tx.db == nil {
|
||||
return
|
||||
}
|
||||
if tx.writable {
|
||||
tx.db.freelist.rollback(tx.meta.txid)
|
||||
tx.db.freelist.reload(tx.db.page(tx.db.meta().freelist))
|
||||
}
|
||||
tx.close()
|
||||
}
|
||||
|
||||
func (tx *Tx) close() {
|
||||
if tx.db == nil {
|
||||
return
|
||||
}
|
||||
if tx.writable {
|
||||
// Grab freelist stats.
|
||||
var freelistFreeN = tx.db.freelist.free_count()
|
||||
var freelistPendingN = tx.db.freelist.pending_count()
|
||||
var freelistAlloc = tx.db.freelist.size()
|
||||
|
||||
// Remove transaction ref & writer lock.
|
||||
tx.db.rwtx = nil
|
||||
tx.db.rwlock.Unlock()
|
||||
|
||||
// Merge statistics.
|
||||
tx.db.statlock.Lock()
|
||||
tx.db.stats.FreePageN = freelistFreeN
|
||||
tx.db.stats.PendingPageN = freelistPendingN
|
||||
tx.db.stats.FreeAlloc = (freelistFreeN + freelistPendingN) * tx.db.pageSize
|
||||
tx.db.stats.FreelistInuse = freelistAlloc
|
||||
tx.db.stats.TxStats.add(&tx.stats)
|
||||
tx.db.statlock.Unlock()
|
||||
} else {
|
||||
tx.db.removeTx(tx)
|
||||
}
|
||||
|
||||
// Clear all references.
|
||||
tx.db = nil
|
||||
tx.meta = nil
|
||||
tx.root = Bucket{tx: tx}
|
||||
tx.pages = nil
|
||||
}
|
||||
|
||||
// Copy writes the entire database to a writer.
|
||||
// This function exists for backwards compatibility. Use WriteTo() instead.
|
||||
func (tx *Tx) Copy(w io.Writer) error {
|
||||
_, err := tx.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteTo writes the entire database to a writer.
|
||||
// If err == nil then exactly tx.Size() bytes will be written into the writer.
|
||||
func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
|
||||
// Attempt to open reader with WriteFlag
|
||||
f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
// Generate a meta page. We use the same page data for both meta pages.
|
||||
buf := make([]byte, tx.db.pageSize)
|
||||
page := (*page)(unsafe.Pointer(&buf[0]))
|
||||
page.flags = metaPageFlag
|
||||
*page.meta() = *tx.meta
|
||||
|
||||
// Write meta 0.
|
||||
page.id = 0
|
||||
page.meta().checksum = page.meta().sum64()
|
||||
nn, err := w.Write(buf)
|
||||
n += int64(nn)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("meta 0 copy: %s", err)
|
||||
}
|
||||
|
||||
// Write meta 1 with a lower transaction id.
|
||||
page.id = 1
|
||||
page.meta().txid -= 1
|
||||
page.meta().checksum = page.meta().sum64()
|
||||
nn, err = w.Write(buf)
|
||||
n += int64(nn)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("meta 1 copy: %s", err)
|
||||
}
|
||||
|
||||
// Move past the meta pages in the file.
|
||||
if _, err := f.Seek(int64(tx.db.pageSize*2), os.SEEK_SET); err != nil {
|
||||
return n, fmt.Errorf("seek: %s", err)
|
||||
}
|
||||
|
||||
// Copy data pages.
|
||||
wn, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2))
|
||||
n += wn
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, f.Close()
|
||||
}
|
||||
|
||||
// CopyFile copies the entire database to file at the given path.
|
||||
// A reader transaction is maintained during the copy so it is safe to continue
|
||||
// using the database while a copy is in progress.
|
||||
func (tx *Tx) CopyFile(path string, mode os.FileMode) error {
|
||||
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Copy(f)
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// Check performs several consistency checks on the database for this transaction.
|
||||
// An error is returned if any inconsistency is found.
|
||||
//
|
||||
// It can be safely run concurrently on a writable transaction. However, this
|
||||
// incurs a high cost for large databases and databases with a lot of subbuckets
|
||||
// because of caching. This overhead can be removed if running on a read-only
|
||||
// transaction, however, it is not safe to execute other writer transactions at
|
||||
// the same time.
|
||||
func (tx *Tx) Check() <-chan error {
|
||||
ch := make(chan error)
|
||||
go tx.check(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (tx *Tx) check(ch chan error) {
|
||||
// Check if any pages are double freed.
|
||||
freed := make(map[pgid]bool)
|
||||
for _, id := range tx.db.freelist.all() {
|
||||
if freed[id] {
|
||||
ch <- fmt.Errorf("page %d: already freed", id)
|
||||
}
|
||||
freed[id] = true
|
||||
}
|
||||
|
||||
// Track every reachable page.
|
||||
reachable := make(map[pgid]*page)
|
||||
reachable[0] = tx.page(0) // meta0
|
||||
reachable[1] = tx.page(1) // meta1
|
||||
for i := uint32(0); i <= tx.page(tx.meta.freelist).overflow; i++ {
|
||||
reachable[tx.meta.freelist+pgid(i)] = tx.page(tx.meta.freelist)
|
||||
}
|
||||
|
||||
// Recursively check buckets.
|
||||
tx.checkBucket(&tx.root, reachable, freed, ch)
|
||||
|
||||
// Ensure all pages below high water mark are either reachable or freed.
|
||||
for i := pgid(0); i < tx.meta.pgid; i++ {
|
||||
_, isReachable := reachable[i]
|
||||
if !isReachable && !freed[i] {
|
||||
ch <- fmt.Errorf("page %d: unreachable unfreed", int(i))
|
||||
}
|
||||
}
|
||||
|
||||
// Close the channel to signal completion.
|
||||
close(ch)
|
||||
}
|
||||
|
||||
func (tx *Tx) checkBucket(b *Bucket, reachable map[pgid]*page, freed map[pgid]bool, ch chan error) {
|
||||
// Ignore inline buckets.
|
||||
if b.root == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Check every page used by this bucket.
|
||||
b.tx.forEachPage(b.root, 0, func(p *page, _ int) {
|
||||
if p.id > tx.meta.pgid {
|
||||
ch <- fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid))
|
||||
}
|
||||
|
||||
// Ensure each page is only referenced once.
|
||||
for i := pgid(0); i <= pgid(p.overflow); i++ {
|
||||
var id = p.id + i
|
||||
if _, ok := reachable[id]; ok {
|
||||
ch <- fmt.Errorf("page %d: multiple references", int(id))
|
||||
}
|
||||
reachable[id] = p
|
||||
}
|
||||
|
||||
// We should only encounter un-freed leaf and branch pages.
|
||||
if freed[p.id] {
|
||||
ch <- fmt.Errorf("page %d: reachable freed", int(p.id))
|
||||
} else if (p.flags&branchPageFlag) == 0 && (p.flags&leafPageFlag) == 0 {
|
||||
ch <- fmt.Errorf("page %d: invalid type: %s", int(p.id), p.typ())
|
||||
}
|
||||
})
|
||||
|
||||
// Check each bucket within this bucket.
|
||||
_ = b.ForEach(func(k, v []byte) error {
|
||||
if child := b.Bucket(k); child != nil {
|
||||
tx.checkBucket(child, reachable, freed, ch)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// allocate returns a contiguous block of memory starting at a given page.
|
||||
func (tx *Tx) allocate(count int) (*page, error) {
|
||||
p, err := tx.db.allocate(count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save to our page cache.
|
||||
tx.pages[p.id] = p
|
||||
|
||||
// Update statistics.
|
||||
tx.stats.PageCount++
|
||||
tx.stats.PageAlloc += count * tx.db.pageSize
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// write writes any dirty pages to disk.
|
||||
func (tx *Tx) write() error {
|
||||
// Sort pages by id.
|
||||
pages := make(pages, 0, len(tx.pages))
|
||||
for _, p := range tx.pages {
|
||||
pages = append(pages, p)
|
||||
}
|
||||
// Clear out page cache early.
|
||||
tx.pages = make(map[pgid]*page)
|
||||
sort.Sort(pages)
|
||||
|
||||
// Write pages to disk in order.
|
||||
for _, p := range pages {
|
||||
size := (int(p.overflow) + 1) * tx.db.pageSize
|
||||
offset := int64(p.id) * int64(tx.db.pageSize)
|
||||
|
||||
// Write out page in "max allocation" sized chunks.
|
||||
ptr := (*[maxAllocSize]byte)(unsafe.Pointer(p))
|
||||
for {
|
||||
// Limit our write to our max allocation size.
|
||||
sz := size
|
||||
if sz > maxAllocSize-1 {
|
||||
sz = maxAllocSize - 1
|
||||
}
|
||||
|
||||
// Write chunk to disk.
|
||||
buf := ptr[:sz]
|
||||
if _, err := tx.db.ops.writeAt(buf, offset); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update statistics.
|
||||
tx.stats.Write++
|
||||
|
||||
// Exit inner for loop if we've written all the chunks.
|
||||
size -= sz
|
||||
if size == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Otherwise move offset forward and move pointer to next chunk.
|
||||
offset += int64(sz)
|
||||
ptr = (*[maxAllocSize]byte)(unsafe.Pointer(&ptr[sz]))
|
||||
}
|
||||
}
|
||||
|
||||
// Ignore file sync if flag is set on DB.
|
||||
if !tx.db.NoSync || IgnoreNoSync {
|
||||
if err := fdatasync(tx.db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Put small pages back to page pool.
|
||||
for _, p := range pages {
|
||||
// Ignore page sizes over 1 page.
|
||||
// These are allocated using make() instead of the page pool.
|
||||
if int(p.overflow) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:tx.db.pageSize]
|
||||
|
||||
// See https://go.googlesource.com/go/+/f03c9202c43e0abb130669852082117ca50aa9b1
|
||||
for i := range buf {
|
||||
buf[i] = 0
|
||||
}
|
||||
tx.db.pagePool.Put(buf)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeMeta writes the meta to the disk.
|
||||
func (tx *Tx) writeMeta() error {
|
||||
// Create a temporary buffer for the meta page.
|
||||
buf := make([]byte, tx.db.pageSize)
|
||||
p := tx.db.pageInBuffer(buf, 0)
|
||||
tx.meta.write(p)
|
||||
|
||||
// Write the meta page to file.
|
||||
if _, err := tx.db.ops.writeAt(buf, int64(p.id)*int64(tx.db.pageSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
if !tx.db.NoSync || IgnoreNoSync {
|
||||
if err := fdatasync(tx.db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update statistics.
|
||||
tx.stats.Write++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// page returns a reference to the page with a given id.
|
||||
// If page has been written to then a temporary buffered page is returned.
|
||||
func (tx *Tx) page(id pgid) *page {
|
||||
// Check the dirty pages first.
|
||||
if tx.pages != nil {
|
||||
if p, ok := tx.pages[id]; ok {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise return directly from the mmap.
|
||||
return tx.db.page(id)
|
||||
}
|
||||
|
||||
// forEachPage iterates over every page within a given page and executes a function.
|
||||
func (tx *Tx) forEachPage(pgid pgid, depth int, fn func(*page, int)) {
|
||||
p := tx.page(pgid)
|
||||
|
||||
// Execute function.
|
||||
fn(p, depth)
|
||||
|
||||
// Recursively loop over children.
|
||||
if (p.flags & branchPageFlag) != 0 {
|
||||
for i := 0; i < int(p.count); i++ {
|
||||
elem := p.branchPageElement(uint16(i))
|
||||
tx.forEachPage(elem.pgid, depth+1, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Page returns page information for a given page number.
|
||||
// This is only safe for concurrent use when used by a writable transaction.
|
||||
func (tx *Tx) Page(id int) (*PageInfo, error) {
|
||||
if tx.db == nil {
|
||||
return nil, ErrTxClosed
|
||||
} else if pgid(id) >= tx.meta.pgid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build the page info.
|
||||
p := tx.db.page(pgid(id))
|
||||
info := &PageInfo{
|
||||
ID: id,
|
||||
Count: int(p.count),
|
||||
OverflowCount: int(p.overflow),
|
||||
}
|
||||
|
||||
// Determine the type (or if it's free).
|
||||
if tx.db.freelist.freed(pgid(id)) {
|
||||
info.Type = "free"
|
||||
} else {
|
||||
info.Type = p.typ()
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// TxStats represents statistics about the actions performed by the transaction.
|
||||
type TxStats struct {
|
||||
// Page statistics.
|
||||
PageCount int // number of page allocations
|
||||
PageAlloc int // total bytes allocated
|
||||
|
||||
// Cursor statistics.
|
||||
CursorCount int // number of cursors created
|
||||
|
||||
// Node statistics
|
||||
NodeCount int // number of node allocations
|
||||
NodeDeref int // number of node dereferences
|
||||
|
||||
// Rebalance statistics.
|
||||
Rebalance int // number of node rebalances
|
||||
RebalanceTime time.Duration // total time spent rebalancing
|
||||
|
||||
// Split/Spill statistics.
|
||||
Split int // number of nodes split
|
||||
Spill int // number of nodes spilled
|
||||
SpillTime time.Duration // total time spent spilling
|
||||
|
||||
// Write statistics.
|
||||
Write int // number of writes performed
|
||||
WriteTime time.Duration // total time spent writing to disk
|
||||
}
|
||||
|
||||
func (s *TxStats) add(other *TxStats) {
|
||||
s.PageCount += other.PageCount
|
||||
s.PageAlloc += other.PageAlloc
|
||||
s.CursorCount += other.CursorCount
|
||||
s.NodeCount += other.NodeCount
|
||||
s.NodeDeref += other.NodeDeref
|
||||
s.Rebalance += other.Rebalance
|
||||
s.RebalanceTime += other.RebalanceTime
|
||||
s.Split += other.Split
|
||||
s.Spill += other.Spill
|
||||
s.SpillTime += other.SpillTime
|
||||
s.Write += other.Write
|
||||
s.WriteTime += other.WriteTime
|
||||
}
|
||||
|
||||
// Sub calculates and returns the difference between two sets of transaction stats.
|
||||
// This is useful when obtaining stats at two different points and time and
|
||||
// you need the performance counters that occurred within that time span.
|
||||
func (s *TxStats) Sub(other *TxStats) TxStats {
|
||||
var diff TxStats
|
||||
diff.PageCount = s.PageCount - other.PageCount
|
||||
diff.PageAlloc = s.PageAlloc - other.PageAlloc
|
||||
diff.CursorCount = s.CursorCount - other.CursorCount
|
||||
diff.NodeCount = s.NodeCount - other.NodeCount
|
||||
diff.NodeDeref = s.NodeDeref - other.NodeDeref
|
||||
diff.Rebalance = s.Rebalance - other.Rebalance
|
||||
diff.RebalanceTime = s.RebalanceTime - other.RebalanceTime
|
||||
diff.Split = s.Split - other.Split
|
||||
diff.Spill = s.Spill - other.Spill
|
||||
diff.SpillTime = s.SpillTime - other.SpillTime
|
||||
diff.Write = s.Write - other.Write
|
||||
diff.WriteTime = s.WriteTime - other.WriteTime
|
||||
return diff
|
||||
}
|
|
@ -0,0 +1,716 @@
|
|||
package bolt_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
// Ensure that committing a closed transaction returns an error.
|
||||
func TestTx_Commit_ErrTxClosed(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket([]byte("foo")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != bolt.ErrTxClosed {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that rolling back a closed transaction returns an error.
|
||||
func TestTx_Rollback_ErrTxClosed(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != bolt.ErrTxClosed {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that committing a read-only transaction returns an error.
|
||||
func TestTx_Commit_ErrTxNotWritable(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
tx, err := db.Begin(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != bolt.ErrTxNotWritable {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a transaction can retrieve a cursor on the root bucket.
|
||||
func TestTx_Cursor(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucket([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket([]byte("woojits")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := tx.Cursor()
|
||||
if k, v := c.First(); !bytes.Equal(k, []byte("widgets")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
if k, v := c.Next(); !bytes.Equal(k, []byte("woojits")) {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
if k, v := c.Next(); k != nil {
|
||||
t.Fatalf("unexpected key: %v", k)
|
||||
} else if v != nil {
|
||||
t.Fatalf("unexpected value: %v", k)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that creating a bucket with a read-only transaction returns an error.
|
||||
func TestTx_CreateBucket_ErrTxNotWritable(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucket([]byte("foo"))
|
||||
if err != bolt.ErrTxNotWritable {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that creating a bucket on a closed transaction returns an error.
|
||||
func TestTx_CreateBucket_ErrTxClosed(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucket([]byte("foo")); err != bolt.ErrTxClosed {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx can retrieve a bucket.
|
||||
func TestTx_Bucket(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucket([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tx.Bucket([]byte("widgets")) == nil {
|
||||
t.Fatal("expected bucket")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a Tx retrieving a non-existent key returns nil.
|
||||
func TestTx_Get_NotFound(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if b.Get([]byte("no_such_key")) != nil {
|
||||
t.Fatal("expected nil value")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a bucket can be created and retrieved.
|
||||
func TestTx_CreateBucket(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
// Create a bucket.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if b == nil {
|
||||
t.Fatal("expected bucket")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the bucket through a separate transaction.
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
if tx.Bucket([]byte("widgets")) == nil {
|
||||
t.Fatal("expected bucket")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a bucket can be created if it doesn't already exist.
|
||||
func TestTx_CreateBucketIfNotExists(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
// Create bucket.
|
||||
if b, err := tx.CreateBucketIfNotExists([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if b == nil {
|
||||
t.Fatal("expected bucket")
|
||||
}
|
||||
|
||||
// Create bucket again.
|
||||
if b, err := tx.CreateBucketIfNotExists([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if b == nil {
|
||||
t.Fatal("expected bucket")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the bucket through a separate transaction.
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
if tx.Bucket([]byte("widgets")) == nil {
|
||||
t.Fatal("expected bucket")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure transaction returns an error if creating an unnamed bucket.
|
||||
func TestTx_CreateBucketIfNotExists_ErrBucketNameRequired(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucketIfNotExists([]byte{}); err != bolt.ErrBucketNameRequired {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucketIfNotExists(nil); err != bolt.ErrBucketNameRequired {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a bucket cannot be created twice.
|
||||
func TestTx_CreateBucket_ErrBucketExists(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
// Create a bucket.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucket([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the same bucket again.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucket([]byte("widgets")); err != bolt.ErrBucketExists {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a bucket is created with a non-blank name.
|
||||
func TestTx_CreateBucket_ErrBucketNameRequired(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucket(nil); err != bolt.ErrBucketNameRequired {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a bucket can be deleted.
|
||||
func TestTx_DeleteBucket(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
// Create a bucket and add a value.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delete the bucket and make sure we can't get the value.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if err := tx.DeleteBucket([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tx.Bucket([]byte("widgets")) != nil {
|
||||
t.Fatal("unexpected bucket")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
// Create the bucket again and make sure there's not a phantom value.
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v := b.Get([]byte("foo")); v != nil {
|
||||
t.Fatalf("unexpected phantom value: %v", v)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that deleting a bucket on a closed transaction returns an error.
|
||||
func TestTx_DeleteBucket_ErrTxClosed(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tx.DeleteBucket([]byte("foo")); err != bolt.ErrTxClosed {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that deleting a bucket with a read-only transaction returns an error.
|
||||
func TestTx_DeleteBucket_ReadOnly(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
if err := tx.DeleteBucket([]byte("foo")); err != bolt.ErrTxNotWritable {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that nothing happens when deleting a bucket that doesn't exist.
|
||||
func TestTx_DeleteBucket_NotFound(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
if err := tx.DeleteBucket([]byte("widgets")); err != bolt.ErrBucketNotFound {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that no error is returned when a tx.ForEach function does not return
|
||||
// an error.
|
||||
func TestTx_ForEach_NoError(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error {
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that an error is returned when a tx.ForEach function returns an error.
|
||||
func TestTx_ForEach_WithError(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
marker := errors.New("marker")
|
||||
if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error {
|
||||
return marker
|
||||
}); err != marker {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that Tx commit handlers are called after a transaction successfully commits.
|
||||
func TestTx_OnCommit(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
var x int
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
tx.OnCommit(func() { x += 1 })
|
||||
tx.OnCommit(func() { x += 2 })
|
||||
if _, err := tx.CreateBucket([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if x != 3 {
|
||||
t.Fatalf("unexpected x: %d", x)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that Tx commit handlers are NOT called after a transaction rolls back.
|
||||
func TestTx_OnCommit_Rollback(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
var x int
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
tx.OnCommit(func() { x += 1 })
|
||||
tx.OnCommit(func() { x += 2 })
|
||||
if _, err := tx.CreateBucket([]byte("widgets")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return errors.New("rollback this commit")
|
||||
}); err == nil || err.Error() != "rollback this commit" {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
} else if x != 0 {
|
||||
t.Fatalf("unexpected x: %d", x)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the database can be copied to a file path.
|
||||
func TestTx_CopyFile(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
|
||||
path := tempfile()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("baz"), []byte("bat")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
return tx.CopyFile(path, 0600)
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db2, err := bolt.Open(path, 0600, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db2.View(func(tx *bolt.Tx) error {
|
||||
if v := tx.Bucket([]byte("widgets")).Get([]byte("foo")); !bytes.Equal(v, []byte("bar")) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
if v := tx.Bucket([]byte("widgets")).Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type failWriterError struct{}
|
||||
|
||||
func (failWriterError) Error() string {
|
||||
return "error injected for tests"
|
||||
}
|
||||
|
||||
type failWriter struct {
|
||||
// fail after this many bytes
|
||||
After int
|
||||
}
|
||||
|
||||
func (f *failWriter) Write(p []byte) (n int, err error) {
|
||||
n = len(p)
|
||||
if n > f.After {
|
||||
n = f.After
|
||||
err = failWriterError{}
|
||||
}
|
||||
f.After -= n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Ensure that Copy handles write errors right.
|
||||
func TestTx_CopyFile_Error_Meta(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("baz"), []byte("bat")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
return tx.Copy(&failWriter{})
|
||||
}); err == nil || err.Error() != "meta 0 copy: error injected for tests" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that Copy handles write errors right.
|
||||
func TestTx_CopyFile_Error_Normal(t *testing.T) {
|
||||
db := MustOpenDB()
|
||||
defer db.MustClose()
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Put([]byte("baz"), []byte("bat")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
return tx.Copy(&failWriter{3 * db.Info().PageSize})
|
||||
}); err == nil || err.Error() != "error injected for tests" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleTx_Rollback() {
|
||||
// Open the database.
|
||||
db, err := bolt.Open(tempfile(), 0666, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove(db.Path())
|
||||
|
||||
// Create a bucket.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucket([]byte("widgets"))
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Set a value for a key.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar"))
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Update the key but rollback the transaction so it never saves.
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
b := tx.Bucket([]byte("widgets"))
|
||||
if err := b.Put([]byte("foo"), []byte("baz")); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Ensure that our original value is still set.
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
value := tx.Bucket([]byte("widgets")).Get([]byte("foo"))
|
||||
fmt.Printf("The value for 'foo' is still: %s\n", value)
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Close database to release file lock.
|
||||
if err := db.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// The value for 'foo' is still: bar
|
||||
}
|
||||
|
||||
func ExampleTx_CopyFile() {
|
||||
// Open the database.
|
||||
db, err := bolt.Open(tempfile(), 0666, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove(db.Path())
|
||||
|
||||
// Create a bucket and a key.
|
||||
if err := db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Put([]byte("foo"), []byte("bar")); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Copy the database to another file.
|
||||
toFile := tempfile()
|
||||
if err := db.View(func(tx *bolt.Tx) error {
|
||||
return tx.CopyFile(toFile, 0666)
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.Remove(toFile)
|
||||
|
||||
// Open the cloned database.
|
||||
db2, err := bolt.Open(toFile, 0666, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Ensure that the key exists in the copy.
|
||||
if err := db2.View(func(tx *bolt.Tx) error {
|
||||
value := tx.Bucket([]byte("widgets")).Get([]byte("foo"))
|
||||
fmt.Printf("The value for 'foo' in the clone is: %s\n", value)
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Close database to release file lock.
|
||||
if err := db.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db2.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// The value for 'foo' in the clone is: bar
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
sudo: false
|
||||
language: go
|
||||
go:
|
||||
- 1.1
|
||||
- 1.2
|
||||
- 1.3
|
||||
- 1.4
|
||||
- 1.5
|
||||
- 1.6
|
||||
- tip
|
|
@ -0,0 +1,24 @@
|
|||
Copyright (c) 2013 Julien Schmidt. All rights reserved.
|
||||
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* The names of the contributors may not be used to endorse or promote
|
||||
products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL JULIEN SCHMIDT BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,275 @@
|
|||
# HttpRouter [![GoDoc](https://godoc.org/github.com/bouk/httprouter?status.svg)](http://godoc.org/github.com/bouk/httprouter)
|
||||
|
||||
HttpRouter is a lightweight high performance HTTP request router (also called *multiplexer* or just *mux* for short) for [Go](https://golang.org/).
|
||||
|
||||
In contrast to the [default mux][http.ServeMux] of Go's `net/http` package, this router supports variables in the routing pattern and matches against the request method. It also scales better.
|
||||
|
||||
The router is optimized for high performance and a small memory footprint. It scales well even with very long paths and a large number of routes. A compressing dynamic trie (radix tree) structure is used for efficient matching.
|
||||
|
||||
## Features
|
||||
|
||||
**Only explicit matches:** With other routers, like [`http.ServeMux`][http.ServeMux], a requested URL path could match multiple patterns. Therefore they have some awkward pattern priority rules, like *longest match* or *first registered, first matched*. By design of this router, a request can only match exactly one or no route. As a result, there are also no unintended matches, which makes it great for SEO and improves the user experience.
|
||||
|
||||
**Stop caring about trailing slashes:** Choose the URL style you like, the router automatically redirects the client if a trailing slash is missing or if there is one extra. Of course it only does so, if the new path has a handler. If you don't like it, you can [turn off this behavior](https://godoc.org/github.com/bouk/httprouter#Router.RedirectTrailingSlash).
|
||||
|
||||
**Path auto-correction:** Besides detecting the missing or additional trailing slash at no extra cost, the router can also fix wrong cases and remove superfluous path elements (like `../` or `//`). Is [CAPTAIN CAPS LOCK](http://www.urbandictionary.com/define.php?term=Captain+Caps+Lock) one of your users? HttpRouter can help him by making a case-insensitive look-up and redirecting him to the correct URL.
|
||||
|
||||
**Parameters in your routing pattern:** Stop parsing the requested URL path, just give the path segment a name and the router delivers the dynamic value to you. Because of the design of the router, path parameters are very cheap.
|
||||
|
||||
**Zero Garbage:** The matching and dispatching process generates zero bytes of garbage. In fact, the only heap allocations that are made, is by building the slice of the key-value pairs for path parameters. If the request path contains no parameters, not a single heap allocation is necessary.
|
||||
|
||||
**Best Performance:** [Benchmarks speak for themselves][benchmark]. See below for technical details of the implementation.
|
||||
|
||||
**No more server crashes:** You can set a [Panic handler][Router.PanicHandler] to deal with panics occurring during handling a HTTP request. The router then recovers and lets the `PanicHandler` log what happened and deliver a nice error page.
|
||||
|
||||
**Perfect for APIs:** The router design encourages to build sensible, hierarchical RESTful APIs. Moreover it has builtin native support for [OPTIONS requests](http://zacstewart.com/2012/04/14/http-options-method.html) and `405 Method Not Allowed` replies.
|
||||
|
||||
Of course you can also set **custom [`NotFound`][Router.NotFound] and [`MethodNotAllowed`](https://godoc.org/github.com/bouk/httprouter#Router.MethodNotAllowed) handlers** and [**serve static files**][Router.ServeFiles].
|
||||
|
||||
## Usage
|
||||
|
||||
This is just a quick introduction, view the [GoDoc](http://godoc.org/github.com/bouk/httprouter) for details.
|
||||
|
||||
Let's start with a trivial example:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/bouk/httprouter"
|
||||
"net/http"
|
||||
"log"
|
||||
)
|
||||
|
||||
func Index(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, "Welcome!\n")
|
||||
}
|
||||
|
||||
func Hello(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "hello, %s!\n", httprouter.GetParam(r, "name"))
|
||||
}
|
||||
|
||||
func main() {
|
||||
router := httprouter.New()
|
||||
router.GET("/", Index)
|
||||
router.GET("/hello/:name", Hello)
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
### Named parameters
|
||||
|
||||
As you can see, `:name` is a *named parameter*. The values are accessible via `httprouter.GetParam`, which returns the value of the given key. Additionally you can access the `httprouter.Params` slice through `httprouter.GetParams`, which allows you to access the parameters by index.
|
||||
Named parameters only match a single path segment:
|
||||
|
||||
```
|
||||
Pattern: /user/:user
|
||||
|
||||
/user/gordon match
|
||||
/user/you match
|
||||
/user/gordon/profile no match
|
||||
/user/ no match
|
||||
```
|
||||
|
||||
**Note:** Since this router has only explicit matches, you can not register static routes and parameters for the same path segment. For example you can not register the patterns `/user/new` and `/user/:user` for the same request method at the same time. The routing of different request methods is independent from each other.
|
||||
|
||||
### Catch-All parameters
|
||||
|
||||
The second type are *catch-all* parameters and have the form `*name`. Like the name suggests, they match everything. Therefore they must always be at the **end** of the pattern:
|
||||
|
||||
```
|
||||
Pattern: /src/*filepath
|
||||
|
||||
/src/ match
|
||||
/src/somefile.go match
|
||||
/src/subdir/somefile.go match
|
||||
```
|
||||
|
||||
## How does it work?
|
||||
|
||||
The router relies on a tree structure which makes heavy use of *common prefixes*, it is basically a *compact* [*prefix tree*](https://en.wikipedia.org/wiki/Trie) (or just [*Radix tree*](https://en.wikipedia.org/wiki/Radix_tree)). Nodes with a common prefix also share a common parent. Here is a short example what the routing tree for the `GET` request method could look like:
|
||||
|
||||
```
|
||||
Priority Path Handle
|
||||
9 \ *<1>
|
||||
3 ├s nil
|
||||
2 |├earch\ *<2>
|
||||
1 |└upport\ *<3>
|
||||
2 ├blog\ *<4>
|
||||
1 | └:post nil
|
||||
1 | └\ *<5>
|
||||
2 ├about-us\ *<6>
|
||||
1 | └team\ *<7>
|
||||
1 └contact\ *<8>
|
||||
```
|
||||
|
||||
Every `*<num>` represents the memory address of a handler function (a pointer). If you follow a path trough the tree from the root to the leaf, you get the complete route path, e.g `\blog\:post\`, where `:post` is just a placeholder ([*parameter*](#named-parameters)) for an actual post name. Unlike hash-maps, a tree structure also allows us to use dynamic parts like the `:post` parameter, since we actually match against the routing patterns instead of just comparing hashes. [As benchmarks show][benchmark], this works very well and efficient.
|
||||
|
||||
Since URL paths have a hierarchical structure and make use only of a limited set of characters (byte values), it is very likely that there are a lot of common prefixes. This allows us to easily reduce the routing into ever smaller problems. Moreover the router manages a separate tree for every request method. For one thing it is more space efficient than holding a method->handle map in every single node, for another thing is also allows us to greatly reduce the routing problem before even starting the look-up in the prefix-tree.
|
||||
|
||||
For even better scalability, the child nodes on each tree level are ordered by priority, where the priority is just the number of handles registered in sub nodes (children, grandchildren, and so on..). This helps in two ways:
|
||||
|
||||
1. Nodes which are part of the most routing paths are evaluated first. This helps to make as much routes as possible to be reachable as fast as possible.
|
||||
2. It is some sort of cost compensation. The longest reachable path (highest cost) can always be evaluated first. The following scheme visualizes the tree structure. Nodes are evaluated from top to bottom and from left to right.
|
||||
|
||||
```
|
||||
├------------
|
||||
├---------
|
||||
├-----
|
||||
├----
|
||||
├--
|
||||
├--
|
||||
└-
|
||||
```
|
||||
|
||||
## Why doesn't this work with `http.Handler`?
|
||||
|
||||
**It does!** The router itself implements the `http.Handler` interface. Moreover the router provides convenient [adapters for `http.Handler`][Router.Handler]s and [`http.HandlerFunc`][Router.HandlerFunc]s which allows them to be used as a [`httprouter.Handle`][Router.Handle] when registering a route. The only disadvantage is, that no parameter values can be retrieved when a `http.Handler` or `http.HandlerFunc` is used, since there is no efficient way to pass the values with the existing function parameters. Therefore [`httprouter.Handle`][Router.Handle] has a third function parameter.
|
||||
|
||||
Just try it out for yourself, the usage of HttpRouter is very straightforward. The package is compact and minimalistic, but also probably one of the easiest routers to set up.
|
||||
|
||||
## Where can I find Middleware *X*?
|
||||
|
||||
This package just provides a very efficient request router with a few extra features. The router is just a [`http.Handler`][http.Handler], you can chain any http.Handler compatible middleware before the router, for example the [Gorilla handlers](http://www.gorillatoolkit.org/pkg/handlers). Or you could [just write your own](https://justinas.org/writing-http-middleware-in-go/), it's very easy!
|
||||
|
||||
Alternatively, you could try [a web framework based on HttpRouter](#web-frameworks-based-on-httprouter).
|
||||
|
||||
### Multi-domain / Sub-domains
|
||||
|
||||
Here is a quick example: Does your server serve multiple domains / hosts?
|
||||
You want to use sub-domains?
|
||||
Define a router per host!
|
||||
|
||||
```go
|
||||
// We need an object that implements the http.Handler interface.
|
||||
// Therefore we need a type for which we implement the ServeHTTP method.
|
||||
// We just use a map here, in which we map host names (with port) to http.Handlers
|
||||
type HostSwitch map[string]http.Handler
|
||||
|
||||
// Implement the ServerHTTP method on our new type
|
||||
func (hs HostSwitch) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if a http.Handler is registered for the given host.
|
||||
// If yes, use it to handle the request.
|
||||
if handler := hs[r.Host]; handler != nil {
|
||||
handler.ServeHTTP(w, r)
|
||||
} else {
|
||||
// Handle host names for wich no handler is registered
|
||||
http.Error(w, "Forbidden", 403) // Or Redirect?
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Initialize a router as usual
|
||||
router := httprouter.New()
|
||||
router.GET("/", Index)
|
||||
router.GET("/hello/:name", Hello)
|
||||
|
||||
// Make a new HostSwitch and insert the router (our http handler)
|
||||
// for example.com and port 12345
|
||||
hs := make(HostSwitch)
|
||||
hs["example.com:12345"] = router
|
||||
|
||||
// Use the HostSwitch to listen and serve on port 12345
|
||||
log.Fatal(http.ListenAndServe(":12345", hs))
|
||||
}
|
||||
```
|
||||
|
||||
### Basic Authentication
|
||||
|
||||
Another quick example: Basic Authentication (RFC 2617) for handles:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bouk/httprouter"
|
||||
)
|
||||
|
||||
func BasicAuth(h httprouter.Handle, requiredUser, requiredPassword string) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the Basic Authentication credentials
|
||||
user, password, hasAuth := r.BasicAuth()
|
||||
|
||||
if hasAuth && user == requiredUser && password == requiredPassword {
|
||||
// Delegate request to the given handle
|
||||
h(w, r)
|
||||
} else {
|
||||
// Request Basic Authentication otherwise
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=Restricted")
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
fmt.Fprint(w, "Not protected!\n")
|
||||
}
|
||||
|
||||
func Protected(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
fmt.Fprint(w, "Protected!\n")
|
||||
}
|
||||
|
||||
func main() {
|
||||
user := "gordon"
|
||||
pass := "secret!"
|
||||
|
||||
router := httprouter.New()
|
||||
router.GET("/", Index)
|
||||
router.GET("/protected/", BasicAuth(Protected, user, pass))
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
## Chaining with the NotFound handler
|
||||
|
||||
**NOTE: It might be required to set [`Router.HandleMethodNotAllowed`][Router.HandleMethodNotAllowed] to `false` to avoid problems.**
|
||||
|
||||
You can use another [`http.Handler`][http.Handler], for example another router, to handle requests which could not be matched by this router by using the [`Router.NotFound`][Router.NotFound] handler. This allows chaining.
|
||||
|
||||
### Static files
|
||||
|
||||
The `NotFound` handler can for example be used to serve static files from the root path `/` (like an `index.html` file along with other assets):
|
||||
|
||||
```go
|
||||
// Serve static files from the ./public directory
|
||||
router.NotFound = http.FileServer(http.Dir("public"))
|
||||
```
|
||||
|
||||
But this approach sidesteps the strict core rules of this router to avoid routing problems. A cleaner approach is to use a distinct sub-path for serving files, like `/static/*filepath` or `/files/*filepath`.
|
||||
|
||||
## Web Frameworks based on HttpRouter
|
||||
|
||||
If the HttpRouter is a bit too minimalistic for you, you might try one of the following more high-level 3rd-party web frameworks building upon the HttpRouter package:
|
||||
|
||||
* [Ace](https://github.com/plimble/ace): Blazing fast Go Web Framework
|
||||
* [api2go](https://github.com/manyminds/api2go): A JSON API Implementation for Go
|
||||
* [Gin](https://github.com/gin-gonic/gin): Features a martini-like API with much better performance
|
||||
* [Goat](https://github.com/bahlo/goat): A minimalistic REST API server in Go
|
||||
* [Hikaru](https://github.com/najeira/hikaru): Supports standalone and Google AppEngine
|
||||
* [Hitch](https://github.com/nbio/hitch): Hitch ties httprouter, [httpcontext](https://github.com/nbio/httpcontext), and middleware up in a bow
|
||||
* [httpway](https://github.com/corneldamian/httpway): Simple middleware extension with context for httprouter and a server with gracefully shutdown support
|
||||
* [kami](https://github.com/guregu/kami): A tiny web framework using x/net/context
|
||||
* [Medeina](https://github.com/imdario/medeina): Inspired by Ruby's Roda and Cuba
|
||||
* [Neko](https://github.com/rocwong/neko): A lightweight web application framework for Golang
|
||||
* [River](https://github.com/abiosoft/river): River is a simple and lightweight REST server
|
||||
* [Roxanna](https://github.com/iamthemuffinman/Roxanna): An amalgamation of httprouter, better logging, and hot reload
|
||||
* [siesta](https://github.com/VividCortex/siesta): Composable HTTP handlers with contexts
|
||||
* [xmux](https://github.com/rs/xmux): xmux is a httprouter fork on top of xhandler (net/context aware)
|
||||
|
||||
[benchmark]: <https://github.com/bouk/go-http-routing-benchmark>
|
||||
[http.Handler]: <https://golang.org/pkg/net/http/#Handler
|
||||
[http.ServeMux]: <https://golang.org/pkg/net/http/#ServeMux>
|
||||
[Router.Handle]: <https://godoc.org/github.com/bouk/httprouter#Router.Handle>
|
||||
[Router.HandleMethodNotAllowed]: <https://godoc.org/github.com/bouk/httprouter#Router.HandleMethodNotAllowed>
|
||||
[Router.Handler]: <https://godoc.org/github.com/bouk/httprouter#Router.Handler>
|
||||
[Router.HandlerFunc]: <https://godoc.org/github.com/bouk/httprouter#Router.HandlerFunc>
|
||||
[Router.NotFound]: <https://godoc.org/github.com/bouk/httprouter#Router.NotFound>
|
||||
[Router.PanicHandler]: <https://godoc.org/github.com/bouk/httprouter#Router.PanicHandler>
|
||||
[Router.ServeFiles]: <https://godoc.org/github.com/bouk/httprouter#Router.ServeFiles>
|
|
@ -0,0 +1,57 @@
|
|||
package httprouter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Param is a single URL parameter, consisting of a key and a value.
|
||||
type Param struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Params is a Param-slice, as returned by the router.
|
||||
// The slice is ordered, the first URL parameter is also the first slice value.
|
||||
// It is therefore safe to read values by the index.
|
||||
type Params []Param
|
||||
|
||||
// ByName returns the value of the first Param which key matches the given name.
|
||||
// If no matching Param is found, an empty string is returned.
|
||||
func (ps Params) ByName(name string) string {
|
||||
for i := range ps {
|
||||
if ps[i].Key == name {
|
||||
return ps[i].Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type paramsContextKey struct{}
|
||||
|
||||
// WithParams returns a new Context that contains ps.
|
||||
func WithParams(ctx context.Context, ps Params) context.Context {
|
||||
return context.WithValue(ctx, paramsContextKey{}, ps)
|
||||
}
|
||||
|
||||
// GetParams returns the Params from the Context of the Request.
|
||||
func GetParams(req *http.Request) Params {
|
||||
return GetParamsFromContext(req.Context())
|
||||
}
|
||||
|
||||
// GetParamsFromContext returns the Params from the Context.
|
||||
func GetParamsFromContext(ctx context.Context) Params {
|
||||
ps, _ := ctx.Value(paramsContextKey{}).(Params)
|
||||
return ps
|
||||
}
|
||||
|
||||
// GetParam returns the value of the first Param which key matches in the Context
|
||||
// of the Request.
|
||||
func GetParam(req *http.Request, key string) string {
|
||||
return GetParamFromContext(req.Context(), key)
|
||||
}
|
||||
|
||||
// GetParam returns the value of the first Param which key matches in the Context.
|
||||
func GetParamFromContext(ctx context.Context, key string) string {
|
||||
return GetParamsFromContext(ctx).ByName(key)
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package httprouter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParams(t *testing.T) {
|
||||
ps := Params{
|
||||
Param{"param1", "value1"},
|
||||
Param{"param2", "value2"},
|
||||
Param{"param3", "value3"},
|
||||
}
|
||||
for i := range ps {
|
||||
if val := ps.ByName(ps[i].Key); val != ps[i].Value {
|
||||
t.Errorf("Wrong value for %s: Got %s; Want %s", ps[i].Key, val, ps[i].Value)
|
||||
}
|
||||
}
|
||||
if val := ps.ByName("noKey"); val != "" {
|
||||
t.Errorf("Expected empty string for not found key; got: %s", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextParams(t *testing.T) {
|
||||
ps := Params{
|
||||
Param{"param1", "value1"},
|
||||
Param{"param2", "value2"},
|
||||
Param{"param3", "value3"},
|
||||
}
|
||||
req, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
req = req.WithContext(WithParams(req.Context(), ps))
|
||||
for i := range ps {
|
||||
if val := GetParam(req, ps[i].Key); val != ps[i].Value {
|
||||
t.Errorf("Wrong value for %s: Got %s; Want %s", ps[i].Key, val, ps[i].Value)
|
||||
}
|
||||
}
|
||||
if val := GetParam(req, "noKey"); val != "" {
|
||||
t.Errorf("Expected empty string for not found key; got: %s", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingContextParams(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
if value := GetParam(req, "whatever"); value != "" {
|
||||
t.Fatalf("wrong empty parameter value: got %v", value)
|
||||
}
|
||||
ps := GetParams(req)
|
||||
if value := ps.ByName("whatever"); value != "" {
|
||||
t.Fatalf("wrong empty parameter value: got %v", value)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// Based on the path package, Copyright 2009 The Go Authors.
|
||||
// Use of this source code is governed by a BSD-style license that can be found
|
||||
// in the LICENSE file.
|
||||
|
||||
package httprouter
|
||||
|
||||
// CleanPath is the URL version of path.Clean, it returns a canonical URL path
|
||||
// for p, eliminating . and .. elements.
|
||||
//
|
||||
// The following rules are applied iteratively until no further processing can
|
||||
// be done:
|
||||
// 1. Replace multiple slashes with a single slash.
|
||||
// 2. Eliminate each . path name element (the current directory).
|
||||
// 3. Eliminate each inner .. path name element (the parent directory)
|
||||
// along with the non-.. element that precedes it.
|
||||
// 4. Eliminate .. elements that begin a rooted path:
|
||||
// that is, replace "/.." by "/" at the beginning of a path.
|
||||
//
|
||||
// If the result of this process is an empty string, "/" is returned
|
||||
func CleanPath(p string) string {
|
||||
// Turn empty string into "/"
|
||||
if p == "" {
|
||||
return "/"
|
||||
}
|
||||
|
||||
n := len(p)
|
||||
var buf []byte
|
||||
|
||||
// Invariants:
|
||||
// reading from path; r is index of next byte to process.
|
||||
// writing to buf; w is index of next byte to write.
|
||||
|
||||
// path must start with '/'
|
||||
r := 1
|
||||
w := 1
|
||||
|
||||
if p[0] != '/' {
|
||||
r = 0
|
||||
buf = make([]byte, n+1)
|
||||
buf[0] = '/'
|
||||
}
|
||||
|
||||
trailing := n > 2 && p[n-1] == '/'
|
||||
|
||||
// A bit more clunky without a 'lazybuf' like the path package, but the loop
|
||||
// gets completely inlined (bufApp). So in contrast to the path package this
|
||||
// loop has no expensive function calls (except 1x make)
|
||||
|
||||
for r < n {
|
||||
switch {
|
||||
case p[r] == '/':
|
||||
// empty path element, trailing slash is added after the end
|
||||
r++
|
||||
|
||||
case p[r] == '.' && r+1 == n:
|
||||
trailing = true
|
||||
r++
|
||||
|
||||
case p[r] == '.' && p[r+1] == '/':
|
||||
// . element
|
||||
r++
|
||||
|
||||
case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'):
|
||||
// .. element: remove to last /
|
||||
r += 2
|
||||
|
||||
if w > 1 {
|
||||
// can backtrack
|
||||
w--
|
||||
|
||||
if buf == nil {
|
||||
for w > 1 && p[w] != '/' {
|
||||
w--
|
||||
}
|
||||
} else {
|
||||
for w > 1 && buf[w] != '/' {
|
||||
w--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// real path element.
|
||||
// add slash if needed
|
||||
if w > 1 {
|
||||
bufApp(&buf, p, w, '/')
|
||||
w++
|
||||
}
|
||||
|
||||
// copy element
|
||||
for r < n && p[r] != '/' {
|
||||
bufApp(&buf, p, w, p[r])
|
||||
w++
|
||||
r++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// re-append trailing slash
|
||||
if trailing && w > 1 {
|
||||
bufApp(&buf, p, w, '/')
|
||||
w++
|
||||
}
|
||||
|
||||
if buf == nil {
|
||||
return p[:w]
|
||||
}
|
||||
return string(buf[:w])
|
||||
}
|
||||
|
||||
// internal helper to lazily create a buffer if necessary
|
||||
func bufApp(buf *[]byte, s string, w int, c byte) {
|
||||
if *buf == nil {
|
||||
if s[w] == c {
|
||||
return
|
||||
}
|
||||
|
||||
*buf = make([]byte, len(s))
|
||||
copy(*buf, s[:w])
|
||||
}
|
||||
(*buf)[w] = c
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// Based on the path package, Copyright 2009 The Go Authors.
|
||||
// Use of this source code is governed by a BSD-style license that can be found
|
||||
// in the LICENSE file.
|
||||
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var cleanTests = []struct {
|
||||
path, result string
|
||||
}{
|
||||
// Already clean
|
||||
{"/", "/"},
|
||||
{"/abc", "/abc"},
|
||||
{"/a/b/c", "/a/b/c"},
|
||||
{"/abc/", "/abc/"},
|
||||
{"/a/b/c/", "/a/b/c/"},
|
||||
|
||||
// missing root
|
||||
{"", "/"},
|
||||
{"abc", "/abc"},
|
||||
{"abc/def", "/abc/def"},
|
||||
{"a/b/c", "/a/b/c"},
|
||||
|
||||
// Remove doubled slash
|
||||
{"//", "/"},
|
||||
{"/abc//", "/abc/"},
|
||||
{"/abc/def//", "/abc/def/"},
|
||||
{"/a/b/c//", "/a/b/c/"},
|
||||
{"/abc//def//ghi", "/abc/def/ghi"},
|
||||
{"//abc", "/abc"},
|
||||
{"///abc", "/abc"},
|
||||
{"//abc//", "/abc/"},
|
||||
|
||||
// Remove . elements
|
||||
{".", "/"},
|
||||
{"./", "/"},
|
||||
{"/abc/./def", "/abc/def"},
|
||||
{"/./abc/def", "/abc/def"},
|
||||
{"/abc/.", "/abc/"},
|
||||
|
||||
// Remove .. elements
|
||||
{"..", "/"},
|
||||
{"../", "/"},
|
||||
{"../../", "/"},
|
||||
{"../..", "/"},
|
||||
{"../../abc", "/abc"},
|
||||
{"/abc/def/ghi/../jkl", "/abc/def/jkl"},
|
||||
{"/abc/def/../ghi/../jkl", "/abc/jkl"},
|
||||
{"/abc/def/..", "/abc"},
|
||||
{"/abc/def/../..", "/"},
|
||||
{"/abc/def/../../..", "/"},
|
||||
{"/abc/def/../../..", "/"},
|
||||
{"/abc/def/../../../ghi/jkl/../../../mno", "/mno"},
|
||||
|
||||
// Combinations
|
||||
{"abc/./../def", "/def"},
|
||||
{"abc//./../def", "/def"},
|
||||
{"abc/../../././../def", "/def"},
|
||||
}
|
||||
|
||||
func TestPathClean(t *testing.T) {
|
||||
for _, test := range cleanTests {
|
||||
if s := CleanPath(test.path); s != test.result {
|
||||
t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result)
|
||||
}
|
||||
if s := CleanPath(test.result); s != test.result {
|
||||
t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathCleanMallocs(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping malloc count in short mode")
|
||||
}
|
||||
if runtime.GOMAXPROCS(0) > 1 {
|
||||
t.Log("skipping AllocsPerRun checks; GOMAXPROCS>1")
|
||||
return
|
||||
}
|
||||
|
||||
for _, test := range cleanTests {
|
||||
allocs := testing.AllocsPerRun(100, func() { CleanPath(test.result) })
|
||||
if allocs > 0 {
|
||||
t.Errorf("CleanPath(%q): %v allocs, want zero", test.result, allocs)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,375 @@
|
|||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be found
|
||||
// in the LICENSE file.
|
||||
|
||||
// Package httprouter is a trie based high performance HTTP request router.
|
||||
//
|
||||
// A trivial example is:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "fmt"
|
||||
// "github.com/bouk/httprouter"
|
||||
// "net/http"
|
||||
// "log"
|
||||
// )
|
||||
//
|
||||
// func Index(w http.ResponseWriter, r *http.Request) {
|
||||
// fmt.Fprint(w, "Welcome!\n")
|
||||
// }
|
||||
//
|
||||
// func Hello(w http.ResponseWriter, r *http.Request) {
|
||||
// fmt.Fprintf(w, "hello, %s!\n", httprouter.GetParam(r, "name"))
|
||||
// }
|
||||
//
|
||||
// func main() {
|
||||
// router := httprouter.New()
|
||||
// router.GET("/", Index)
|
||||
// router.GET("/hello/:name", Hello)
|
||||
//
|
||||
// log.Fatal(http.ListenAndServe(":8080", router))
|
||||
// }
|
||||
//
|
||||
// The router matches incoming requests by the request method and the path.
|
||||
// If a handle is registered for this path and method, the router delegates the
|
||||
// request to that function.
|
||||
// For the methods GET, POST, PUT, PATCH and DELETE shortcut functions exist to
|
||||
// register handles, for all other methods router.Handle can be used.
|
||||
//
|
||||
// The registered path, against which the router matches incoming requests, can
|
||||
// contain two types of parameters:
|
||||
// Syntax Type
|
||||
// :name named parameter
|
||||
// *name catch-all parameter
|
||||
//
|
||||
// Named parameters are dynamic path segments. They match anything until the
|
||||
// next '/' or the path end:
|
||||
// Path: /blog/:category/:post
|
||||
//
|
||||
// Requests:
|
||||
// /blog/go/request-routers match: category="go", post="request-routers"
|
||||
// /blog/go/request-routers/ no match, but the router would redirect
|
||||
// /blog/go/ no match
|
||||
// /blog/go/request-routers/comments no match
|
||||
//
|
||||
// Catch-all parameters match anything until the path end, including the
|
||||
// directory index (the '/' before the catch-all). Since they match anything
|
||||
// until the end, catch-all parameters must always be the final path element.
|
||||
// Path: /files/*filepath
|
||||
//
|
||||
// Requests:
|
||||
// /files/ match: filepath="/"
|
||||
// /files/LICENSE match: filepath="/LICENSE"
|
||||
// /files/templates/article.html match: filepath="/templates/article.html"
|
||||
// /files no match, but the router would redirect
|
||||
//
|
||||
// The value of parameters is saved as a slice of the Param struct, consisting
|
||||
// each of a key and a value. The slice is passed to the Handle func as a third
|
||||
// parameter.
|
||||
// There are two ways to retrieve the value of a parameter:
|
||||
// // by the name of the parameter
|
||||
// user := ps.ByName("user") // defined by :user or *user
|
||||
//
|
||||
// // by the index of the parameter. This way you can also get the name (key)
|
||||
// thirdKey := ps[2].Key // the name of the 3rd parameter
|
||||
// thirdValue := ps[2].Value // the value of the 3rd parameter
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Router is a http.Handler which can be used to dispatch requests to different
|
||||
// handler functions via configurable routes
|
||||
type Router struct {
|
||||
trees map[string]*node
|
||||
|
||||
// Enables automatic redirection if the current route can't be matched but a
|
||||
// handler for the path with (without) the trailing slash exists.
|
||||
// For example if /foo/ is requested but a route only exists for /foo, the
|
||||
// client is redirected to /foo with http status code 301 for GET requests
|
||||
// and 307 for all other request methods.
|
||||
RedirectTrailingSlash bool
|
||||
|
||||
// If enabled, the router tries to fix the current request path, if no
|
||||
// handle is registered for it.
|
||||
// First superfluous path elements like ../ or // are removed.
|
||||
// Afterwards the router does a case-insensitive lookup of the cleaned path.
|
||||
// If a handle can be found for this route, the router makes a redirection
|
||||
// to the corrected path with status code 301 for GET requests and 307 for
|
||||
// all other request methods.
|
||||
// For example /FOO and /..//Foo could be redirected to /foo.
|
||||
// RedirectTrailingSlash is independent of this option.
|
||||
RedirectFixedPath bool
|
||||
|
||||
// If enabled, the router checks if another method is allowed for the
|
||||
// current route, if the current request can not be routed.
|
||||
// If this is the case, the request is answered with 'Method Not Allowed'
|
||||
// and HTTP status code 405.
|
||||
// If no other Method is allowed, the request is delegated to the NotFound
|
||||
// handler.
|
||||
HandleMethodNotAllowed bool
|
||||
|
||||
// If enabled, the router automatically replies to OPTIONS requests.
|
||||
// Custom OPTIONS handlers take priority over automatic replies.
|
||||
HandleOPTIONS bool
|
||||
|
||||
// Configurable http.Handler which is called when no matching route is
|
||||
// found. If it is not set, http.NotFound is used.
|
||||
NotFound http.Handler
|
||||
|
||||
// Configurable http.Handler which is called when a request
|
||||
// cannot be routed and HandleMethodNotAllowed is true.
|
||||
// If it is not set, http.Error with http.StatusMethodNotAllowed is used.
|
||||
// The "Allow" header with allowed request methods is set before the handler
|
||||
// is called.
|
||||
MethodNotAllowed http.Handler
|
||||
|
||||
// Function to handle panics recovered from http handlers.
|
||||
// It should be used to generate a error page and return the http error code
|
||||
// 500 (Internal Server Error).
|
||||
// The handler can be used to keep your server from crashing because of
|
||||
// unrecovered panics.
|
||||
PanicHandler func(http.ResponseWriter, *http.Request, interface{})
|
||||
}
|
||||
|
||||
// Make sure the Router conforms with the http.Handler interface
|
||||
var _ http.Handler = New()
|
||||
|
||||
// New returns a new initialized Router.
|
||||
// Path auto-correction, including trailing slashes, is enabled by default.
|
||||
func New() *Router {
|
||||
return &Router{
|
||||
RedirectTrailingSlash: true,
|
||||
RedirectFixedPath: true,
|
||||
HandleMethodNotAllowed: true,
|
||||
HandleOPTIONS: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GET is a shortcut for router.Handle("GET", path, handle)
|
||||
func (r *Router) GET(path string, handle http.HandlerFunc) {
|
||||
r.Handle("GET", path, handle)
|
||||
}
|
||||
|
||||
// HEAD is a shortcut for router.Handle("HEAD", path, handle)
|
||||
func (r *Router) HEAD(path string, handle http.HandlerFunc) {
|
||||
r.Handle("HEAD", path, handle)
|
||||
}
|
||||
|
||||
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle)
|
||||
func (r *Router) OPTIONS(path string, handle http.HandlerFunc) {
|
||||
r.Handle("OPTIONS", path, handle)
|
||||
}
|
||||
|
||||
// POST is a shortcut for router.Handle("POST", path, handle)
|
||||
func (r *Router) POST(path string, handle http.HandlerFunc) {
|
||||
r.Handle("POST", path, handle)
|
||||
}
|
||||
|
||||
// PUT is a shortcut for router.Handle("PUT", path, handle)
|
||||
func (r *Router) PUT(path string, handle http.HandlerFunc) {
|
||||
r.Handle("PUT", path, handle)
|
||||
}
|
||||
|
||||
// PATCH is a shortcut for router.Handle("PATCH", path, handle)
|
||||
func (r *Router) PATCH(path string, handle http.HandlerFunc) {
|
||||
r.Handle("PATCH", path, handle)
|
||||
}
|
||||
|
||||
// DELETE is a shortcut for router.Handle("DELETE", path, handle)
|
||||
func (r *Router) DELETE(path string, handle http.HandlerFunc) {
|
||||
r.Handle("DELETE", path, handle)
|
||||
}
|
||||
|
||||
// Handle registers a new request handle with the given path and method.
|
||||
//
|
||||
// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut
|
||||
// functions can be used.
|
||||
//
|
||||
// This function is intended for bulk loading and to allow the usage of less
|
||||
// frequently used, non-standardized or custom methods (e.g. for internal
|
||||
// communication with a proxy).
|
||||
func (r *Router) Handle(method, path string, handle http.HandlerFunc) {
|
||||
r.Handler(method, path, handle)
|
||||
}
|
||||
|
||||
func (r *Router) Handler(method, path string, handler http.Handler) {
|
||||
if path[0] != '/' {
|
||||
panic("path must begin with '/' in path '" + path + "'")
|
||||
}
|
||||
|
||||
if r.trees == nil {
|
||||
r.trees = make(map[string]*node)
|
||||
}
|
||||
|
||||
root := r.trees[method]
|
||||
if root == nil {
|
||||
root = new(node)
|
||||
r.trees[method] = root
|
||||
}
|
||||
|
||||
root.addRoute(path, handler)
|
||||
}
|
||||
|
||||
// ServeFiles serves files from the given file system root.
|
||||
// The path must end with "/*filepath", files are then served from the local
|
||||
// path /defined/root/dir/*filepath.
|
||||
// For example if root is "/etc" and *filepath is "passwd", the local file
|
||||
// "/etc/passwd" would be served.
|
||||
// Internally a http.FileServer is used, therefore http.NotFound is used instead
|
||||
// of the Router's NotFound handler.
|
||||
// To use the operating system's file system implementation,
|
||||
// use http.Dir:
|
||||
// router.ServeFiles("/src/*filepath", http.Dir("/var/www"))
|
||||
func (r *Router) ServeFiles(path string, root http.FileSystem) {
|
||||
if len(path) < 10 || path[len(path)-10:] != "/*filepath" {
|
||||
panic("path must end with /*filepath in path '" + path + "'")
|
||||
}
|
||||
|
||||
fileServer := http.FileServer(root)
|
||||
|
||||
r.GET(path, func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL.Path = GetParam(req, "filepath")
|
||||
fileServer.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Router) recv(w http.ResponseWriter, req *http.Request) {
|
||||
if rcv := recover(); rcv != nil {
|
||||
r.PanicHandler(w, req, rcv)
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup allows the manual lookup of a method + path combo.
|
||||
// This is e.g. useful to build a framework around this router.
|
||||
// If the path was found, it returns the handle function and the path parameter
|
||||
// values. Otherwise the third return value indicates whether a redirection to
|
||||
// the same path with an extra / without the trailing slash should be performed.
|
||||
// It should be noted that the handler won't have access to the Params, unless
|
||||
// you add it to the request using WithParams
|
||||
func (r *Router) Lookup(method, path string) (http.Handler, Params, bool) {
|
||||
if root := r.trees[method]; root != nil {
|
||||
return root.getValue(path)
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func (r *Router) allowed(path, reqMethod string) (allow string) {
|
||||
if path == "*" { // server-wide
|
||||
for method := range r.trees {
|
||||
if method == "OPTIONS" {
|
||||
continue
|
||||
}
|
||||
|
||||
// add request method to list of allowed methods
|
||||
if len(allow) == 0 {
|
||||
allow = method
|
||||
} else {
|
||||
allow += ", " + method
|
||||
}
|
||||
}
|
||||
} else { // specific path
|
||||
for method := range r.trees {
|
||||
// Skip the requested method - we already tried this one
|
||||
if method == reqMethod || method == "OPTIONS" {
|
||||
continue
|
||||
}
|
||||
|
||||
handle, _, _ := r.trees[method].getValue(path)
|
||||
if handle != nil {
|
||||
// add request method to list of allowed methods
|
||||
if len(allow) == 0 {
|
||||
allow = method
|
||||
} else {
|
||||
allow += ", " + method
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(allow) > 0 {
|
||||
allow += ", OPTIONS"
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ServeHTTP makes the router implement the http.Handler interface.
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.PanicHandler != nil {
|
||||
defer r.recv(w, req)
|
||||
}
|
||||
|
||||
path := req.URL.Path
|
||||
|
||||
if root := r.trees[req.Method]; root != nil {
|
||||
if handle, ps, tsr := root.getValue(path); handle != nil {
|
||||
req = req.WithContext(WithParams(req.Context(), ps))
|
||||
handle.ServeHTTP(w, req)
|
||||
return
|
||||
} else if req.Method != "CONNECT" && path != "/" {
|
||||
code := 301 // Permanent redirect, request with GET method
|
||||
if req.Method != "GET" {
|
||||
// Temporary redirect, request with same method
|
||||
// As of Go 1.3, Go does not support status code 308.
|
||||
code = 307
|
||||
}
|
||||
|
||||
if tsr && r.RedirectTrailingSlash {
|
||||
if len(path) > 1 && path[len(path)-1] == '/' {
|
||||
req.URL.Path = path[:len(path)-1]
|
||||
} else {
|
||||
req.URL.Path = path + "/"
|
||||
}
|
||||
http.Redirect(w, req, req.URL.String(), code)
|
||||
return
|
||||
}
|
||||
|
||||
// Try to fix the request path
|
||||
if r.RedirectFixedPath {
|
||||
fixedPath, found := root.findCaseInsensitivePath(
|
||||
CleanPath(path),
|
||||
r.RedirectTrailingSlash,
|
||||
)
|
||||
if found {
|
||||
req.URL.Path = string(fixedPath)
|
||||
http.Redirect(w, req, req.URL.String(), code)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
// Handle OPTIONS requests
|
||||
if r.HandleOPTIONS {
|
||||
if allow := r.allowed(path, req.Method); len(allow) > 0 {
|
||||
w.Header().Set("Allow", allow)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle 405
|
||||
if r.HandleMethodNotAllowed {
|
||||
if allow := r.allowed(path, req.Method); len(allow) > 0 {
|
||||
w.Header().Set("Allow", allow)
|
||||
if r.MethodNotAllowed != nil {
|
||||
r.MethodNotAllowed.ServeHTTP(w, req)
|
||||
} else {
|
||||
http.Error(w,
|
||||
http.StatusText(http.StatusMethodNotAllowed),
|
||||
http.StatusMethodNotAllowed,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 404
|
||||
if r.NotFound != nil {
|
||||
r.NotFound.ServeHTTP(w, req)
|
||||
} else {
|
||||
http.NotFound(w, req)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,515 @@
|
|||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be found
|
||||
// in the LICENSE file.
|
||||
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct{}
|
||||
|
||||
func (m *mockResponseWriter) Header() (h http.Header) {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) WriteString(s string) (n int, err error) {
|
||||
return len(s), nil
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) WriteHeader(int) {}
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
router := New()
|
||||
|
||||
routed := false
|
||||
router.Handle("GET", "/user/:name", func(w http.ResponseWriter, r *http.Request) {
|
||||
ps := GetParams(r)
|
||||
routed = true
|
||||
want := Params{Param{"name", "gopher"}}
|
||||
if !reflect.DeepEqual(ps, want) {
|
||||
t.Fatalf("wrong wildcard values: want %v, got %v", want, ps)
|
||||
}
|
||||
})
|
||||
|
||||
w := new(mockResponseWriter)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/user/gopher", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if !routed {
|
||||
t.Fatal("routing failed")
|
||||
}
|
||||
}
|
||||
|
||||
type handlerStruct struct {
|
||||
handeled *bool
|
||||
}
|
||||
|
||||
func (h handlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
*h.handeled = true
|
||||
}
|
||||
|
||||
func TestRouterAPI(t *testing.T) {
|
||||
var get, head, options, post, put, patch, delete, handler, handlerFunc bool
|
||||
|
||||
httpHandler := handlerStruct{&handler}
|
||||
|
||||
router := New()
|
||||
router.GET("/GET", func(w http.ResponseWriter, r *http.Request) {
|
||||
get = true
|
||||
})
|
||||
router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request) {
|
||||
head = true
|
||||
})
|
||||
router.OPTIONS("/GET", func(w http.ResponseWriter, r *http.Request) {
|
||||
options = true
|
||||
})
|
||||
router.POST("/POST", func(w http.ResponseWriter, r *http.Request) {
|
||||
post = true
|
||||
})
|
||||
router.PUT("/PUT", func(w http.ResponseWriter, r *http.Request) {
|
||||
put = true
|
||||
})
|
||||
router.PATCH("/PATCH", func(w http.ResponseWriter, r *http.Request) {
|
||||
patch = true
|
||||
})
|
||||
router.DELETE("/DELETE", func(w http.ResponseWriter, r *http.Request) {
|
||||
delete = true
|
||||
})
|
||||
router.Handler("GET", "/Handler", httpHandler)
|
||||
router.Handle("GET", "/HandlerFunc", func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerFunc = true
|
||||
})
|
||||
|
||||
w := new(mockResponseWriter)
|
||||
|
||||
r, _ := http.NewRequest("GET", "/GET", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !get {
|
||||
t.Error("routing GET failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("HEAD", "/GET", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !head {
|
||||
t.Error("routing HEAD failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("OPTIONS", "/GET", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !options {
|
||||
t.Error("routing OPTIONS failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("POST", "/POST", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !post {
|
||||
t.Error("routing POST failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("PUT", "/PUT", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !put {
|
||||
t.Error("routing PUT failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("PATCH", "/PATCH", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !patch {
|
||||
t.Error("routing PATCH failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("DELETE", "/DELETE", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !delete {
|
||||
t.Error("routing DELETE failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("GET", "/Handler", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !handler {
|
||||
t.Error("routing Handler failed")
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("GET", "/HandlerFunc", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !handlerFunc {
|
||||
t.Error("routing HandlerFunc failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterRoot(t *testing.T) {
|
||||
router := New()
|
||||
recv := catchPanic(func() {
|
||||
router.GET("noSlashRoot", nil)
|
||||
})
|
||||
if recv == nil {
|
||||
t.Fatal("registering path not beginning with '/' did not panic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterChaining(t *testing.T) {
|
||||
router1 := New()
|
||||
router2 := New()
|
||||
router1.NotFound = router2
|
||||
|
||||
fooHit := false
|
||||
router1.POST("/foo", func(w http.ResponseWriter, req *http.Request) {
|
||||
fooHit = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
barHit := false
|
||||
router2.POST("/bar", func(w http.ResponseWriter, req *http.Request) {
|
||||
barHit = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, _ := http.NewRequest("POST", "/foo", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router1.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK && fooHit) {
|
||||
t.Errorf("Regular routing failed with router chaining.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("POST", "/bar", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router1.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK && barHit) {
|
||||
t.Errorf("Chained routing failed with router chaining.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("POST", "/qax", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router1.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusNotFound) {
|
||||
t.Errorf("NotFound behavior failed with router chaining.")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterOPTIONS(t *testing.T) {
|
||||
handlerFunc := func(_ http.ResponseWriter, _ *http.Request) {}
|
||||
|
||||
router := New()
|
||||
router.POST("/path", handlerFunc)
|
||||
|
||||
// test not allowed
|
||||
// * (server)
|
||||
r, _ := http.NewRequest("OPTIONS", "*", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
|
||||
// path
|
||||
r, _ = http.NewRequest("OPTIONS", "/path", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
|
||||
r, _ = http.NewRequest("OPTIONS", "/doesnotexist", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusNotFound) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
}
|
||||
|
||||
// add another method
|
||||
router.GET("/path", handlerFunc)
|
||||
|
||||
// test again
|
||||
// * (server)
|
||||
r, _ = http.NewRequest("OPTIONS", "*", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
|
||||
// path
|
||||
r, _ = http.NewRequest("OPTIONS", "/path", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
|
||||
// custom handler
|
||||
var custom bool
|
||||
router.OPTIONS("/path", func(w http.ResponseWriter, r *http.Request) {
|
||||
custom = true
|
||||
})
|
||||
|
||||
// test again
|
||||
// * (server)
|
||||
r, _ = http.NewRequest("OPTIONS", "*", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
if custom {
|
||||
t.Error("custom handler called on *")
|
||||
}
|
||||
|
||||
// path
|
||||
r, _ = http.NewRequest("OPTIONS", "/path", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusOK) {
|
||||
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
}
|
||||
if !custom {
|
||||
t.Error("custom handler not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterNotAllowed(t *testing.T) {
|
||||
handlerFunc := func(_ http.ResponseWriter, _ *http.Request) {}
|
||||
|
||||
router := New()
|
||||
router.POST("/path", handlerFunc)
|
||||
|
||||
// test not allowed
|
||||
r, _ := http.NewRequest("GET", "/path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusMethodNotAllowed) {
|
||||
t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
|
||||
// add another method
|
||||
router.DELETE("/path", handlerFunc)
|
||||
router.OPTIONS("/path", handlerFunc) // must be ignored
|
||||
|
||||
// test again
|
||||
r, _ = http.NewRequest("GET", "/path", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == http.StatusMethodNotAllowed) {
|
||||
t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
} else if allow := w.Header().Get("Allow"); allow != "POST, DELETE, OPTIONS" && allow != "DELETE, POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
|
||||
// test custom handler
|
||||
w = httptest.NewRecorder()
|
||||
responseText := "custom method"
|
||||
router.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
w.Write([]byte(responseText))
|
||||
})
|
||||
router.ServeHTTP(w, r)
|
||||
if got := w.Body.String(); !(got == responseText) {
|
||||
t.Errorf("unexpected response got %q want %q", got, responseText)
|
||||
}
|
||||
if w.Code != http.StatusTeapot {
|
||||
t.Errorf("unexpected response code %d want %d", w.Code, http.StatusTeapot)
|
||||
}
|
||||
if allow := w.Header().Get("Allow"); allow != "POST, DELETE, OPTIONS" && allow != "DELETE, POST, OPTIONS" {
|
||||
t.Error("unexpected Allow header value: " + allow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterNotFound(t *testing.T) {
|
||||
handlerFunc := func(_ http.ResponseWriter, _ *http.Request) {}
|
||||
|
||||
router := New()
|
||||
router.GET("/path", handlerFunc)
|
||||
router.GET("/dir/", handlerFunc)
|
||||
router.GET("/", handlerFunc)
|
||||
|
||||
testRoutes := []struct {
|
||||
route string
|
||||
code int
|
||||
header string
|
||||
}{
|
||||
{"/path/", 301, "map[Location:[/path]]"}, // TSR -/
|
||||
{"/dir", 301, "map[Location:[/dir/]]"}, // TSR +/
|
||||
{"", 301, "map[Location:[/]]"}, // TSR +/
|
||||
{"/PATH", 301, "map[Location:[/path]]"}, // Fixed Case
|
||||
{"/DIR/", 301, "map[Location:[/dir/]]"}, // Fixed Case
|
||||
{"/PATH/", 301, "map[Location:[/path]]"}, // Fixed Case -/
|
||||
{"/DIR", 301, "map[Location:[/dir/]]"}, // Fixed Case +/
|
||||
{"/../path", 301, "map[Location:[/path]]"}, // CleanPath
|
||||
{"/nope", 404, ""}, // NotFound
|
||||
}
|
||||
for _, tr := range testRoutes {
|
||||
r, _ := http.NewRequest("GET", tr.route, nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == tr.code && (w.Code == 404 || fmt.Sprint(w.Header()) == tr.header)) {
|
||||
t.Errorf("NotFound handling route %s failed: Code=%d, Header=%v", tr.route, w.Code, w.Header())
|
||||
}
|
||||
}
|
||||
|
||||
// Test custom not found handler
|
||||
var notFound bool
|
||||
router.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(404)
|
||||
notFound = true
|
||||
})
|
||||
r, _ := http.NewRequest("GET", "/nope", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == 404 && notFound == true) {
|
||||
t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
}
|
||||
|
||||
// Test other method than GET (want 307 instead of 301)
|
||||
router.PATCH("/path", handlerFunc)
|
||||
r, _ = http.NewRequest("PATCH", "/path/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == 307 && fmt.Sprint(w.Header()) == "map[Location:[/path]]") {
|
||||
t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", w.Code, w.Header())
|
||||
}
|
||||
|
||||
// Test special case where no node for the prefix "/" exists
|
||||
router = New()
|
||||
router.GET("/a", handlerFunc)
|
||||
r, _ = http.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, r)
|
||||
if !(w.Code == 404) {
|
||||
t.Errorf("NotFound handling route / failed: Code=%d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterPanicHandler(t *testing.T) {
|
||||
router := New()
|
||||
panicHandled := false
|
||||
|
||||
router.PanicHandler = func(rw http.ResponseWriter, r *http.Request, p interface{}) {
|
||||
panicHandled = true
|
||||
}
|
||||
|
||||
router.Handle("PUT", "/user/:name", func(_ http.ResponseWriter, _ *http.Request) {
|
||||
panic("oops!")
|
||||
})
|
||||
|
||||
w := new(mockResponseWriter)
|
||||
req, _ := http.NewRequest("PUT", "/user/gopher", nil)
|
||||
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
t.Fatal("handling panic failed")
|
||||
}
|
||||
}()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if !panicHandled {
|
||||
t.Fatal("simulating failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterLookup(t *testing.T) {
|
||||
routed := false
|
||||
wantHandle := func(_ http.ResponseWriter, _ *http.Request) {
|
||||
routed = true
|
||||
}
|
||||
wantParams := Params{Param{"name", "gopher"}}
|
||||
|
||||
router := New()
|
||||
|
||||
// try empty router first
|
||||
handle, _, tsr := router.Lookup("GET", "/nope")
|
||||
if handle != nil {
|
||||
t.Fatalf("Got handle for unregistered pattern: %v", handle)
|
||||
}
|
||||
if tsr {
|
||||
t.Error("Got wrong TSR recommendation!")
|
||||
}
|
||||
|
||||
// insert route and try again
|
||||
router.GET("/user/:name", wantHandle)
|
||||
|
||||
handle, params, tsr := router.Lookup("GET", "/user/gopher")
|
||||
if handle == nil {
|
||||
t.Fatal("Got no handle!")
|
||||
} else {
|
||||
handle.ServeHTTP(nil, nil)
|
||||
if !routed {
|
||||
t.Fatal("Routing failed!")
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(params, wantParams) {
|
||||
t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params)
|
||||
}
|
||||
|
||||
handle, _, tsr = router.Lookup("GET", "/user/gopher/")
|
||||
if handle != nil {
|
||||
t.Fatalf("Got handle for unregistered pattern: %v", handle)
|
||||
}
|
||||
if !tsr {
|
||||
t.Error("Got no TSR recommendation!")
|
||||
}
|
||||
|
||||
handle, _, tsr = router.Lookup("GET", "/nope")
|
||||
if handle != nil {
|
||||
t.Fatalf("Got handle for unregistered pattern: %v", handle)
|
||||
}
|
||||
if tsr {
|
||||
t.Error("Got wrong TSR recommendation!")
|
||||
}
|
||||
}
|
||||
|
||||
type mockFileSystem struct {
|
||||
opened bool
|
||||
}
|
||||
|
||||
func (mfs *mockFileSystem) Open(name string) (http.File, error) {
|
||||
mfs.opened = true
|
||||
return nil, errors.New("this is just a mock")
|
||||
}
|
||||
|
||||
func TestRouterServeFiles(t *testing.T) {
|
||||
router := New()
|
||||
mfs := &mockFileSystem{}
|
||||
|
||||
recv := catchPanic(func() {
|
||||
router.ServeFiles("/noFilepath", mfs)
|
||||
})
|
||||
if recv == nil {
|
||||
t.Fatal("registering path not ending with '*filepath' did not panic")
|
||||
}
|
||||
|
||||
router.ServeFiles("/*filepath", mfs)
|
||||
w := new(mockResponseWriter)
|
||||
r, _ := http.NewRequest("GET", "/favicon.ico", nil)
|
||||
router.ServeHTTP(w, r)
|
||||
if !mfs.opened {
|
||||
t.Error("serving file failed")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,650 @@
|
|||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be found
|
||||
// in the LICENSE file.
|
||||
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func min(a, b int) int {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func countParams(path string) uint8 {
|
||||
var n uint
|
||||
for i := 0; i < len(path); i++ {
|
||||
if path[i] != ':' && path[i] != '*' {
|
||||
continue
|
||||
}
|
||||
n++
|
||||
}
|
||||
if n >= 255 {
|
||||
return 255
|
||||
}
|
||||
return uint8(n)
|
||||
}
|
||||
|
||||
type nodeType uint8
|
||||
|
||||
const (
|
||||
static nodeType = iota // default
|
||||
root
|
||||
param
|
||||
catchAll
|
||||
)
|
||||
|
||||
type node struct {
|
||||
path string
|
||||
wildChild bool
|
||||
nType nodeType
|
||||
maxParams uint8
|
||||
indices string
|
||||
children []*node
|
||||
handle http.Handler
|
||||
priority uint32
|
||||
}
|
||||
|
||||
// increments priority of the given child and reorders if necessary
|
||||
func (n *node) incrementChildPrio(pos int) int {
|
||||
n.children[pos].priority++
|
||||
prio := n.children[pos].priority
|
||||
|
||||
// adjust position (move to front)
|
||||
newPos := pos
|
||||
for newPos > 0 && n.children[newPos-1].priority < prio {
|
||||
// swap node positions
|
||||
tmpN := n.children[newPos-1]
|
||||
n.children[newPos-1] = n.children[newPos]
|
||||
n.children[newPos] = tmpN
|
||||
|
||||
newPos--
|
||||
}
|
||||
|
||||
// build new index char string
|
||||
if newPos != pos {
|
||||
n.indices = n.indices[:newPos] + // unchanged prefix, might be empty
|
||||
n.indices[pos:pos+1] + // the index char we move
|
||||
n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos'
|
||||
}
|
||||
|
||||
return newPos
|
||||
}
|
||||
|
||||
// addRoute adds a node with the given handle to the path.
|
||||
// Not concurrency-safe!
|
||||
func (n *node) addRoute(path string, handle http.Handler) {
|
||||
fullPath := path
|
||||
n.priority++
|
||||
numParams := countParams(path)
|
||||
|
||||
// non-empty tree
|
||||
if len(n.path) > 0 || len(n.children) > 0 {
|
||||
walk:
|
||||
for {
|
||||
// Update maxParams of the current node
|
||||
if numParams > n.maxParams {
|
||||
n.maxParams = numParams
|
||||
}
|
||||
|
||||
// Find the longest common prefix.
|
||||
// This also implies that the common prefix contains no ':' or '*'
|
||||
// since the existing key can't contain those chars.
|
||||
i := 0
|
||||
max := min(len(path), len(n.path))
|
||||
for i < max && path[i] == n.path[i] {
|
||||
i++
|
||||
}
|
||||
|
||||
// Split edge
|
||||
if i < len(n.path) {
|
||||
child := node{
|
||||
path: n.path[i:],
|
||||
wildChild: n.wildChild,
|
||||
nType: static,
|
||||
indices: n.indices,
|
||||
children: n.children,
|
||||
handle: n.handle,
|
||||
priority: n.priority - 1,
|
||||
}
|
||||
|
||||
// Update maxParams (max of all children)
|
||||
for i := range child.children {
|
||||
if child.children[i].maxParams > child.maxParams {
|
||||
child.maxParams = child.children[i].maxParams
|
||||
}
|
||||
}
|
||||
|
||||
n.children = []*node{&child}
|
||||
// []byte for proper unicode char conversion, see #65
|
||||
n.indices = string([]byte{n.path[i]})
|
||||
n.path = path[:i]
|
||||
n.handle = nil
|
||||
n.wildChild = false
|
||||
}
|
||||
|
||||
// Make new node a child of this node
|
||||
if i < len(path) {
|
||||
path = path[i:]
|
||||
|
||||
if n.wildChild {
|
||||
n = n.children[0]
|
||||
n.priority++
|
||||
|
||||
// Update maxParams of the child node
|
||||
if numParams > n.maxParams {
|
||||
n.maxParams = numParams
|
||||
}
|
||||
numParams--
|
||||
|
||||
// Check if the wildcard matches
|
||||
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
|
||||
// check for longer wildcard, e.g. :name and :names
|
||||
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
||||
panic("path segment '" + path +
|
||||
"' conflicts with existing wildcard '" + n.path +
|
||||
"' in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
c := path[0]
|
||||
|
||||
// slash after param
|
||||
if n.nType == param && c == '/' && len(n.children) == 1 {
|
||||
n = n.children[0]
|
||||
n.priority++
|
||||
continue walk
|
||||
}
|
||||
|
||||
// Check if a child with the next path byte exists
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if c == n.indices[i] {
|
||||
i = n.incrementChildPrio(i)
|
||||
n = n.children[i]
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise insert it
|
||||
if c != ':' && c != '*' {
|
||||
// []byte for proper unicode char conversion, see #65
|
||||
n.indices += string([]byte{c})
|
||||
child := &node{
|
||||
maxParams: numParams,
|
||||
}
|
||||
n.children = append(n.children, child)
|
||||
n.incrementChildPrio(len(n.indices) - 1)
|
||||
n = child
|
||||
}
|
||||
n.insertChild(numParams, path, fullPath, handle)
|
||||
return
|
||||
|
||||
} else if i == len(path) { // Make node a (in-path) leaf
|
||||
if n.handle != nil {
|
||||
panic("a handle is already registered for path '" + fullPath + "'")
|
||||
}
|
||||
n.handle = handle
|
||||
}
|
||||
return
|
||||
}
|
||||
} else { // Empty tree
|
||||
n.insertChild(numParams, path, fullPath, handle)
|
||||
n.nType = root
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) insertChild(numParams uint8, path, fullPath string, handle http.Handler) {
|
||||
var offset int // already handled bytes of the path
|
||||
|
||||
// find prefix until first wildcard (beginning with ':'' or '*'')
|
||||
for i, max := 0, len(path); numParams > 0; i++ {
|
||||
c := path[i]
|
||||
if c != ':' && c != '*' {
|
||||
continue
|
||||
}
|
||||
|
||||
// find wildcard end (either '/' or path end)
|
||||
end := i + 1
|
||||
for end < max && path[end] != '/' {
|
||||
switch path[end] {
|
||||
// the wildcard name must not contain ':' and '*'
|
||||
case ':', '*':
|
||||
panic("only one wildcard per path segment is allowed, has: '" +
|
||||
path[i:] + "' in path '" + fullPath + "'")
|
||||
default:
|
||||
end++
|
||||
}
|
||||
}
|
||||
|
||||
// check if this Node existing children which would be
|
||||
// unreachable if we insert the wildcard here
|
||||
if len(n.children) > 0 {
|
||||
panic("wildcard route '" + path[i:end] +
|
||||
"' conflicts with existing children in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
// check if the wildcard has a name
|
||||
if end-i < 2 {
|
||||
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
if c == ':' { // param
|
||||
// split path at the beginning of the wildcard
|
||||
if i > 0 {
|
||||
n.path = path[offset:i]
|
||||
offset = i
|
||||
}
|
||||
|
||||
child := &node{
|
||||
nType: param,
|
||||
maxParams: numParams,
|
||||
}
|
||||
n.children = []*node{child}
|
||||
n.wildChild = true
|
||||
n = child
|
||||
n.priority++
|
||||
numParams--
|
||||
|
||||
// if the path doesn't end with the wildcard, then there
|
||||
// will be another non-wildcard subpath starting with '/'
|
||||
if end < max {
|
||||
n.path = path[offset:end]
|
||||
offset = end
|
||||
|
||||
child := &node{
|
||||
maxParams: numParams,
|
||||
priority: 1,
|
||||
}
|
||||
n.children = []*node{child}
|
||||
n = child
|
||||
}
|
||||
|
||||
} else { // catchAll
|
||||
if end != max || numParams > 1 {
|
||||
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
|
||||
panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
// currently fixed width 1 for '/'
|
||||
i--
|
||||
if path[i] != '/' {
|
||||
panic("no / before catch-all in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
n.path = path[offset:i]
|
||||
|
||||
// first node: catchAll node with empty path
|
||||
child := &node{
|
||||
wildChild: true,
|
||||
nType: catchAll,
|
||||
maxParams: 1,
|
||||
}
|
||||
n.children = []*node{child}
|
||||
n.indices = string(path[i])
|
||||
n = child
|
||||
n.priority++
|
||||
|
||||
// second node: node holding the variable
|
||||
child = &node{
|
||||
path: path[i:],
|
||||
nType: catchAll,
|
||||
maxParams: 1,
|
||||
handle: handle,
|
||||
priority: 1,
|
||||
}
|
||||
n.children = []*node{child}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// insert remaining path part and handle to the leaf
|
||||
n.path = path[offset:]
|
||||
n.handle = handle
|
||||
}
|
||||
|
||||
// Returns the handle registered with the given path (key). The values of
|
||||
// wildcards are saved to a map.
|
||||
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
|
||||
// made if a handle exists with an extra (without the) trailing slash for the
|
||||
// given path.
|
||||
func (n *node) getValue(path string) (handle http.Handler, p Params, tsr bool) {
|
||||
walk: // outer loop for walking the tree
|
||||
for {
|
||||
if len(path) > len(n.path) {
|
||||
if path[:len(n.path)] == n.path {
|
||||
path = path[len(n.path):]
|
||||
// If this node does not have a wildcard (param or catchAll)
|
||||
// child, we can just look up the next child node and continue
|
||||
// to walk down the tree
|
||||
if !n.wildChild {
|
||||
c := path[0]
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if c == n.indices[i] {
|
||||
n = n.children[i]
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing found.
|
||||
// We can recommend to redirect to the same URL without a
|
||||
// trailing slash if a leaf exists for that path.
|
||||
tsr = (path == "/" && n.handle != nil)
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// handle wildcard child
|
||||
n = n.children[0]
|
||||
switch n.nType {
|
||||
case param:
|
||||
// find param end (either '/' or path end)
|
||||
end := 0
|
||||
for end < len(path) && path[end] != '/' {
|
||||
end++
|
||||
}
|
||||
|
||||
// save param value
|
||||
if p == nil {
|
||||
// lazy allocation
|
||||
p = make(Params, 0, n.maxParams)
|
||||
}
|
||||
i := len(p)
|
||||
p = p[:i+1] // expand slice within preallocated capacity
|
||||
p[i].Key = n.path[1:]
|
||||
p[i].Value = path[:end]
|
||||
|
||||
// we need to go deeper!
|
||||
if end < len(path) {
|
||||
if len(n.children) > 0 {
|
||||
path = path[end:]
|
||||
n = n.children[0]
|
||||
continue walk
|
||||
}
|
||||
|
||||
// ... but we can't
|
||||
tsr = (len(path) == end+1)
|
||||
return
|
||||
}
|
||||
|
||||
if handle = n.handle; handle != nil {
|
||||
return
|
||||
} else if len(n.children) == 1 {
|
||||
// No handle found. Check if a handle for this path + a
|
||||
// trailing slash exists for TSR recommendation
|
||||
n = n.children[0]
|
||||
tsr = (n.path == "/" && n.handle != nil)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
case catchAll:
|
||||
// save param value
|
||||
if p == nil {
|
||||
// lazy allocation
|
||||
p = make(Params, 0, n.maxParams)
|
||||
}
|
||||
i := len(p)
|
||||
p = p[:i+1] // expand slice within preallocated capacity
|
||||
p[i].Key = n.path[2:]
|
||||
p[i].Value = path
|
||||
|
||||
handle = n.handle
|
||||
return
|
||||
|
||||
default:
|
||||
panic("invalid node type")
|
||||
}
|
||||
}
|
||||
} else if path == n.path {
|
||||
// We should have reached the node containing the handle.
|
||||
// Check if this node has a handle registered.
|
||||
if handle = n.handle; handle != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if path == "/" && n.wildChild && n.nType != root {
|
||||
tsr = true
|
||||
return
|
||||
}
|
||||
|
||||
// No handle found. Check if a handle for this path + a
|
||||
// trailing slash exists for trailing slash recommendation
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == '/' {
|
||||
n = n.children[i]
|
||||
tsr = (len(n.path) == 1 && n.handle != nil) ||
|
||||
(n.nType == catchAll && n.children[0].handle != nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Nothing found. We can recommend to redirect to the same URL with an
|
||||
// extra trailing slash if a leaf exists for that path
|
||||
tsr = (path == "/") ||
|
||||
(len(n.path) == len(path)+1 && n.path[len(path)] == '/' &&
|
||||
path == n.path[:len(n.path)-1] && n.handle != nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Makes a case-insensitive lookup of the given path and tries to find a handler.
|
||||
// It can optionally also fix trailing slashes.
|
||||
// It returns the case-corrected path and a bool indicating whether the lookup
|
||||
// was successful.
|
||||
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) {
|
||||
return n.findCaseInsensitivePathRec(
|
||||
path,
|
||||
strings.ToLower(path),
|
||||
make([]byte, 0, len(path)+1), // preallocate enough memory for new path
|
||||
[4]byte{}, // empty rune buffer
|
||||
fixTrailingSlash,
|
||||
)
|
||||
}
|
||||
|
||||
// shift bytes in array by n bytes left
|
||||
func shiftNRuneBytes(rb [4]byte, n int) [4]byte {
|
||||
switch n {
|
||||
case 0:
|
||||
return rb
|
||||
case 1:
|
||||
return [4]byte{rb[1], rb[2], rb[3], 0}
|
||||
case 2:
|
||||
return [4]byte{rb[2], rb[3]}
|
||||
case 3:
|
||||
return [4]byte{rb[3]}
|
||||
default:
|
||||
return [4]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
// recursive case-insensitive lookup function used by n.findCaseInsensitivePath
|
||||
func (n *node) findCaseInsensitivePathRec(path, loPath string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) ([]byte, bool) {
|
||||
loNPath := strings.ToLower(n.path)
|
||||
|
||||
walk: // outer loop for walking the tree
|
||||
for len(loPath) >= len(loNPath) && (len(loNPath) == 0 || loPath[1:len(loNPath)] == loNPath[1:]) {
|
||||
// add common path to result
|
||||
ciPath = append(ciPath, n.path...)
|
||||
|
||||
if path = path[len(n.path):]; len(path) > 0 {
|
||||
loOld := loPath
|
||||
loPath = loPath[len(loNPath):]
|
||||
|
||||
// If this node does not have a wildcard (param or catchAll) child,
|
||||
// we can just look up the next child node and continue to walk down
|
||||
// the tree
|
||||
if !n.wildChild {
|
||||
// skip rune bytes already processed
|
||||
rb = shiftNRuneBytes(rb, len(loNPath))
|
||||
|
||||
if rb[0] != 0 {
|
||||
// old rune not finished
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == rb[0] {
|
||||
// continue with child node
|
||||
n = n.children[i]
|
||||
loNPath = strings.ToLower(n.path)
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// process a new rune
|
||||
var rv rune
|
||||
|
||||
// find rune start
|
||||
// runes are up to 4 byte long,
|
||||
// -4 would definitely be another rune
|
||||
var off int
|
||||
for max := min(len(loNPath), 3); off < max; off++ {
|
||||
if i := len(loNPath) - off; utf8.RuneStart(loOld[i]) {
|
||||
// read rune from cached lowercase path
|
||||
rv, _ = utf8.DecodeRuneInString(loOld[i:])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// calculate lowercase bytes of current rune
|
||||
utf8.EncodeRune(rb[:], rv)
|
||||
// skipp already processed bytes
|
||||
rb = shiftNRuneBytes(rb, off)
|
||||
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
// lowercase matches
|
||||
if n.indices[i] == rb[0] {
|
||||
// must use a recursive approach since both the
|
||||
// uppercase byte and the lowercase byte might exist
|
||||
// as an index
|
||||
if out, found := n.children[i].findCaseInsensitivePathRec(
|
||||
path, loPath, ciPath, rb, fixTrailingSlash,
|
||||
); found {
|
||||
return out, true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// same for uppercase rune, if it differs
|
||||
if up := unicode.ToUpper(rv); up != rv {
|
||||
utf8.EncodeRune(rb[:], up)
|
||||
rb = shiftNRuneBytes(rb, off)
|
||||
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
// uppercase matches
|
||||
if n.indices[i] == rb[0] {
|
||||
// continue with child node
|
||||
n = n.children[i]
|
||||
loNPath = strings.ToLower(n.path)
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing found. We can recommend to redirect to the same URL
|
||||
// without a trailing slash if a leaf exists for that path
|
||||
return ciPath, (fixTrailingSlash && path == "/" && n.handle != nil)
|
||||
}
|
||||
|
||||
n = n.children[0]
|
||||
switch n.nType {
|
||||
case param:
|
||||
// find param end (either '/' or path end)
|
||||
k := 0
|
||||
for k < len(path) && path[k] != '/' {
|
||||
k++
|
||||
}
|
||||
|
||||
// add param value to case insensitive path
|
||||
ciPath = append(ciPath, path[:k]...)
|
||||
|
||||
// we need to go deeper!
|
||||
if k < len(path) {
|
||||
if len(n.children) > 0 {
|
||||
// continue with child node
|
||||
n = n.children[0]
|
||||
loNPath = strings.ToLower(n.path)
|
||||
loPath = loPath[k:]
|
||||
path = path[k:]
|
||||
continue
|
||||
}
|
||||
|
||||
// ... but we can't
|
||||
if fixTrailingSlash && len(path) == k+1 {
|
||||
return ciPath, true
|
||||
}
|
||||
return ciPath, false
|
||||
}
|
||||
|
||||
if n.handle != nil {
|
||||
return ciPath, true
|
||||
} else if fixTrailingSlash && len(n.children) == 1 {
|
||||
// No handle found. Check if a handle for this path + a
|
||||
// trailing slash exists
|
||||
n = n.children[0]
|
||||
if n.path == "/" && n.handle != nil {
|
||||
return append(ciPath, '/'), true
|
||||
}
|
||||
}
|
||||
return ciPath, false
|
||||
|
||||
case catchAll:
|
||||
return append(ciPath, path...), true
|
||||
|
||||
default:
|
||||
panic("invalid node type")
|
||||
}
|
||||
} else {
|
||||
// We should have reached the node containing the handle.
|
||||
// Check if this node has a handle registered.
|
||||
if n.handle != nil {
|
||||
return ciPath, true
|
||||
}
|
||||
|
||||
// No handle found.
|
||||
// Try to fix the path by adding a trailing slash
|
||||
if fixTrailingSlash {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == '/' {
|
||||
n = n.children[i]
|
||||
if (len(n.path) == 1 && n.handle != nil) ||
|
||||
(n.nType == catchAll && n.children[0].handle != nil) {
|
||||
return append(ciPath, '/'), true
|
||||
}
|
||||
return ciPath, false
|
||||
}
|
||||
}
|
||||
}
|
||||
return ciPath, false
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing found.
|
||||
// Try to fix the path by adding / removing a trailing slash
|
||||
if fixTrailingSlash {
|
||||
if path == "/" {
|
||||
return ciPath, true
|
||||
}
|
||||
if len(loPath)+1 == len(loNPath) && loNPath[len(loPath)] == '/' &&
|
||||
loPath[1:] == loNPath[1:len(loPath)] && n.handle != nil {
|
||||
return append(ciPath, n.path...), true
|
||||
}
|
||||
}
|
||||
return ciPath, false
|
||||
}
|
|
@ -0,0 +1,659 @@
|
|||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be found
|
||||
// in the LICENSE file.
|
||||
|
||||
package httprouter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func printChildren(n *node, prefix string) {
|
||||
fmt.Printf(" %02d:%02d %s%s[%d] %v %t %d \r\n", n.priority, n.maxParams, prefix, n.path, len(n.children), n.handle, n.wildChild, n.nType)
|
||||
for l := len(n.path); l > 0; l-- {
|
||||
prefix += " "
|
||||
}
|
||||
for _, child := range n.children {
|
||||
printChildren(child, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Used as a workaround since we can't compare functions or their addresses
|
||||
var fakeHandlerValue string
|
||||
|
||||
func fakeHandler(val string) http.Handler {
|
||||
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
fakeHandlerValue = val
|
||||
})
|
||||
}
|
||||
|
||||
type testRequests []struct {
|
||||
path string
|
||||
nilHandler bool
|
||||
route string
|
||||
ps Params
|
||||
}
|
||||
|
||||
func checkRequests(t *testing.T, tree *node, requests testRequests) {
|
||||
for _, request := range requests {
|
||||
handler, ps, _ := tree.getValue(request.path)
|
||||
|
||||
if handler == nil {
|
||||
if !request.nilHandler {
|
||||
t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path)
|
||||
}
|
||||
} else if request.nilHandler {
|
||||
t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path)
|
||||
} else {
|
||||
handler.ServeHTTP(nil, nil)
|
||||
if fakeHandlerValue != request.route {
|
||||
t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(ps, request.ps) {
|
||||
t.Errorf("Params mismatch for route '%s'", request.path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkPriorities(t *testing.T, n *node) uint32 {
|
||||
var prio uint32
|
||||
for i := range n.children {
|
||||
prio += checkPriorities(t, n.children[i])
|
||||
}
|
||||
|
||||
if n.handle != nil {
|
||||
prio++
|
||||
}
|
||||
|
||||
if n.priority != prio {
|
||||
t.Errorf(
|
||||
"priority mismatch for node '%s': is %d, should be %d",
|
||||
n.path, n.priority, prio,
|
||||
)
|
||||
}
|
||||
|
||||
return prio
|
||||
}
|
||||
|
||||
func checkMaxParams(t *testing.T, n *node) uint8 {
|
||||
var maxParams uint8
|
||||
for i := range n.children {
|
||||
params := checkMaxParams(t, n.children[i])
|
||||
if params > maxParams {
|
||||
maxParams = params
|
||||
}
|
||||
}
|
||||
if n.nType > root && !n.wildChild {
|
||||
maxParams++
|
||||
}
|
||||
|
||||
if n.maxParams != maxParams {
|
||||
t.Errorf(
|
||||
"maxParams mismatch for node '%s': is %d, should be %d",
|
||||
n.path, n.maxParams, maxParams,
|
||||
)
|
||||
}
|
||||
|
||||
return maxParams
|
||||
}
|
||||
|
||||
func TestCountParams(t *testing.T) {
|
||||
if countParams("/path/:param1/static/*catch-all") != 2 {
|
||||
t.Fail()
|
||||
}
|
||||
if countParams(strings.Repeat("/:param", 256)) != 255 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func TestTreeAddAndGet(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/hi",
|
||||
"/contact",
|
||||
"/co",
|
||||
"/c",
|
||||
"/a",
|
||||
"/ab",
|
||||
"/doc/",
|
||||
"/doc/go_faq.html",
|
||||
"/doc/go1.html",
|
||||
"/α",
|
||||
"/β",
|
||||
}
|
||||
for _, route := range routes {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
}
|
||||
|
||||
//printChildren(tree, "")
|
||||
|
||||
checkRequests(t, tree, testRequests{
|
||||
{"/a", false, "/a", nil},
|
||||
{"/", true, "", nil},
|
||||
{"/hi", false, "/hi", nil},
|
||||
{"/contact", false, "/contact", nil},
|
||||
{"/co", false, "/co", nil},
|
||||
{"/con", true, "", nil}, // key mismatch
|
||||
{"/cona", true, "", nil}, // key mismatch
|
||||
{"/no", true, "", nil}, // no matching child
|
||||
{"/ab", false, "/ab", nil},
|
||||
{"/α", false, "/α", nil},
|
||||
{"/β", false, "/β", nil},
|
||||
})
|
||||
|
||||
checkPriorities(t, tree)
|
||||
checkMaxParams(t, tree)
|
||||
}
|
||||
|
||||
func TestTreeWildcard(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/",
|
||||
"/cmd/:tool/:sub",
|
||||
"/cmd/:tool/",
|
||||
"/src/*filepath",
|
||||
"/search/",
|
||||
"/search/:query",
|
||||
"/user_:name",
|
||||
"/user_:name/about",
|
||||
"/files/:dir/*filepath",
|
||||
"/doc/",
|
||||
"/doc/go_faq.html",
|
||||
"/doc/go1.html",
|
||||
"/info/:user/public",
|
||||
"/info/:user/project/:project",
|
||||
}
|
||||
for _, route := range routes {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
}
|
||||
|
||||
//printChildren(tree, "")
|
||||
|
||||
checkRequests(t, tree, testRequests{
|
||||
{"/", false, "/", nil},
|
||||
{"/cmd/test/", false, "/cmd/:tool/", Params{Param{"tool", "test"}}},
|
||||
{"/cmd/test", true, "", Params{Param{"tool", "test"}}},
|
||||
{"/cmd/test/3", false, "/cmd/:tool/:sub", Params{Param{"tool", "test"}, Param{"sub", "3"}}},
|
||||
{"/src/", false, "/src/*filepath", Params{Param{"filepath", "/"}}},
|
||||
{"/src/some/file.png", false, "/src/*filepath", Params{Param{"filepath", "/some/file.png"}}},
|
||||
{"/search/", false, "/search/", nil},
|
||||
{"/search/someth!ng+in+ünìcodé", false, "/search/:query", Params{Param{"query", "someth!ng+in+ünìcodé"}}},
|
||||
{"/search/someth!ng+in+ünìcodé/", true, "", Params{Param{"query", "someth!ng+in+ünìcodé"}}},
|
||||
{"/user_gopher", false, "/user_:name", Params{Param{"name", "gopher"}}},
|
||||
{"/user_gopher/about", false, "/user_:name/about", Params{Param{"name", "gopher"}}},
|
||||
{"/files/js/inc/framework.js", false, "/files/:dir/*filepath", Params{Param{"dir", "js"}, Param{"filepath", "/inc/framework.js"}}},
|
||||
{"/info/gordon/public", false, "/info/:user/public", Params{Param{"user", "gordon"}}},
|
||||
{"/info/gordon/project/go", false, "/info/:user/project/:project", Params{Param{"user", "gordon"}, Param{"project", "go"}}},
|
||||
})
|
||||
|
||||
checkPriorities(t, tree)
|
||||
checkMaxParams(t, tree)
|
||||
}
|
||||
|
||||
func catchPanic(testFunc func()) (recv interface{}) {
|
||||
defer func() {
|
||||
recv = recover()
|
||||
}()
|
||||
|
||||
testFunc()
|
||||
return
|
||||
}
|
||||
|
||||
type testRoute struct {
|
||||
path string
|
||||
conflict bool
|
||||
}
|
||||
|
||||
func testRoutes(t *testing.T, routes []testRoute) {
|
||||
tree := &node{}
|
||||
|
||||
for _, route := range routes {
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route.path, nil)
|
||||
})
|
||||
|
||||
if route.conflict {
|
||||
if recv == nil {
|
||||
t.Errorf("no panic for conflicting route '%s'", route.path)
|
||||
}
|
||||
} else if recv != nil {
|
||||
t.Errorf("unexpected panic for route '%s': %v", route.path, recv)
|
||||
}
|
||||
}
|
||||
|
||||
//printChildren(tree, "")
|
||||
}
|
||||
|
||||
func TestTreeWildcardConflict(t *testing.T) {
|
||||
routes := []testRoute{
|
||||
{"/cmd/:tool/:sub", false},
|
||||
{"/cmd/vet", true},
|
||||
{"/src/*filepath", false},
|
||||
{"/src/*filepathx", true},
|
||||
{"/src/", true},
|
||||
{"/src1/", false},
|
||||
{"/src1/*filepath", true},
|
||||
{"/src2*filepath", true},
|
||||
{"/search/:query", false},
|
||||
{"/search/invalid", true},
|
||||
{"/user_:name", false},
|
||||
{"/user_x", true},
|
||||
{"/user_:name", false},
|
||||
{"/id:id", false},
|
||||
{"/id/:id", true},
|
||||
}
|
||||
testRoutes(t, routes)
|
||||
}
|
||||
|
||||
func TestTreeChildConflict(t *testing.T) {
|
||||
routes := []testRoute{
|
||||
{"/cmd/vet", false},
|
||||
{"/cmd/:tool/:sub", true},
|
||||
{"/src/AUTHORS", false},
|
||||
{"/src/*filepath", true},
|
||||
{"/user_x", false},
|
||||
{"/user_:name", true},
|
||||
{"/id/:id", false},
|
||||
{"/id:id", true},
|
||||
{"/:id", true},
|
||||
{"/*filepath", true},
|
||||
}
|
||||
testRoutes(t, routes)
|
||||
}
|
||||
|
||||
func TestTreeDupliatePath(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/",
|
||||
"/doc/",
|
||||
"/src/*filepath",
|
||||
"/search/:query",
|
||||
"/user_:name",
|
||||
}
|
||||
for _, route := range routes {
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
})
|
||||
if recv != nil {
|
||||
t.Fatalf("panic inserting route '%s': %v", route, recv)
|
||||
}
|
||||
|
||||
// Add again
|
||||
recv = catchPanic(func() {
|
||||
tree.addRoute(route, nil)
|
||||
})
|
||||
if recv == nil {
|
||||
t.Fatalf("no panic while inserting duplicate route '%s", route)
|
||||
}
|
||||
}
|
||||
|
||||
//printChildren(tree, "")
|
||||
|
||||
checkRequests(t, tree, testRequests{
|
||||
{"/", false, "/", nil},
|
||||
{"/doc/", false, "/doc/", nil},
|
||||
{"/src/some/file.png", false, "/src/*filepath", Params{Param{"filepath", "/some/file.png"}}},
|
||||
{"/search/someth!ng+in+ünìcodé", false, "/search/:query", Params{Param{"query", "someth!ng+in+ünìcodé"}}},
|
||||
{"/user_gopher", false, "/user_:name", Params{Param{"name", "gopher"}}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmptyWildcardName(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/user:",
|
||||
"/user:/",
|
||||
"/cmd/:/",
|
||||
"/src/*",
|
||||
}
|
||||
for _, route := range routes {
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route, nil)
|
||||
})
|
||||
if recv == nil {
|
||||
t.Fatalf("no panic while inserting route with empty wildcard name '%s", route)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTreeCatchAllConflict(t *testing.T) {
|
||||
routes := []testRoute{
|
||||
{"/src/*filepath/x", true},
|
||||
{"/src2/", false},
|
||||
{"/src2/*filepath/x", true},
|
||||
}
|
||||
testRoutes(t, routes)
|
||||
}
|
||||
|
||||
func TestTreeCatchAllConflictRoot(t *testing.T) {
|
||||
routes := []testRoute{
|
||||
{"/", false},
|
||||
{"/*filepath", true},
|
||||
}
|
||||
testRoutes(t, routes)
|
||||
}
|
||||
|
||||
func TestTreeDoubleWildcard(t *testing.T) {
|
||||
const panicMsg = "only one wildcard per path segment is allowed"
|
||||
|
||||
routes := [...]string{
|
||||
"/:foo:bar",
|
||||
"/:foo:bar/",
|
||||
"/:foo*bar",
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
tree := &node{}
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route, nil)
|
||||
})
|
||||
|
||||
if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) {
|
||||
t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*func TestTreeDuplicateWildcard(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/:id/:name/:id",
|
||||
}
|
||||
for _, route := range routes {
|
||||
...
|
||||
}
|
||||
}*/
|
||||
|
||||
func TestTreeTrailingSlashRedirect(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/hi",
|
||||
"/b/",
|
||||
"/search/:query",
|
||||
"/cmd/:tool/",
|
||||
"/src/*filepath",
|
||||
"/x",
|
||||
"/x/y",
|
||||
"/y/",
|
||||
"/y/z",
|
||||
"/0/:id",
|
||||
"/0/:id/1",
|
||||
"/1/:id/",
|
||||
"/1/:id/2",
|
||||
"/aa",
|
||||
"/a/",
|
||||
"/admin",
|
||||
"/admin/:category",
|
||||
"/admin/:category/:page",
|
||||
"/doc",
|
||||
"/doc/go_faq.html",
|
||||
"/doc/go1.html",
|
||||
"/no/a",
|
||||
"/no/b",
|
||||
"/api/hello/:name",
|
||||
}
|
||||
for _, route := range routes {
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
})
|
||||
if recv != nil {
|
||||
t.Fatalf("panic inserting route '%s': %v", route, recv)
|
||||
}
|
||||
}
|
||||
|
||||
//printChildren(tree, "")
|
||||
|
||||
tsrRoutes := [...]string{
|
||||
"/hi/",
|
||||
"/b",
|
||||
"/search/gopher/",
|
||||
"/cmd/vet",
|
||||
"/src",
|
||||
"/x/",
|
||||
"/y",
|
||||
"/0/go/",
|
||||
"/1/go",
|
||||
"/a",
|
||||
"/admin/",
|
||||
"/admin/config/",
|
||||
"/admin/config/permissions/",
|
||||
"/doc/",
|
||||
}
|
||||
for _, route := range tsrRoutes {
|
||||
handler, _, tsr := tree.getValue(route)
|
||||
if handler != nil {
|
||||
t.Fatalf("non-nil handler for TSR route '%s", route)
|
||||
} else if !tsr {
|
||||
t.Errorf("expected TSR recommendation for route '%s'", route)
|
||||
}
|
||||
}
|
||||
|
||||
noTsrRoutes := [...]string{
|
||||
"/",
|
||||
"/no",
|
||||
"/no/",
|
||||
"/_",
|
||||
"/_/",
|
||||
"/api/world/abc",
|
||||
}
|
||||
for _, route := range noTsrRoutes {
|
||||
handler, _, tsr := tree.getValue(route)
|
||||
if handler != nil {
|
||||
t.Fatalf("non-nil handler for No-TSR route '%s", route)
|
||||
} else if tsr {
|
||||
t.Errorf("expected no TSR recommendation for route '%s'", route)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTreeRootTrailingSlashRedirect(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute("/:test", fakeHandler("/:test"))
|
||||
})
|
||||
if recv != nil {
|
||||
t.Fatalf("panic inserting test route: %v", recv)
|
||||
}
|
||||
|
||||
handler, _, tsr := tree.getValue("/")
|
||||
if handler != nil {
|
||||
t.Fatalf("non-nil handler")
|
||||
} else if tsr {
|
||||
t.Errorf("expected no TSR recommendation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTreeFindCaseInsensitivePath(t *testing.T) {
|
||||
tree := &node{}
|
||||
|
||||
routes := [...]string{
|
||||
"/hi",
|
||||
"/b/",
|
||||
"/ABC/",
|
||||
"/search/:query",
|
||||
"/cmd/:tool/",
|
||||
"/src/*filepath",
|
||||
"/x",
|
||||
"/x/y",
|
||||
"/y/",
|
||||
"/y/z",
|
||||
"/0/:id",
|
||||
"/0/:id/1",
|
||||
"/1/:id/",
|
||||
"/1/:id/2",
|
||||
"/aa",
|
||||
"/a/",
|
||||
"/doc",
|
||||
"/doc/go_faq.html",
|
||||
"/doc/go1.html",
|
||||
"/doc/go/away",
|
||||
"/no/a",
|
||||
"/no/b",
|
||||
"/Π",
|
||||
"/u/apfêl/",
|
||||
"/u/äpfêl/",
|
||||
"/u/öpfêl",
|
||||
"/v/Äpfêl/",
|
||||
"/v/Öpfêl",
|
||||
"/w/♬", // 3 byte
|
||||
"/w/♭/", // 3 byte, last byte differs
|
||||
"/w/𠜎", // 4 byte
|
||||
"/w/𠜏/", // 4 byte
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
recv := catchPanic(func() {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
})
|
||||
if recv != nil {
|
||||
t.Fatalf("panic inserting route '%s': %v", route, recv)
|
||||
}
|
||||
}
|
||||
|
||||
// Check out == in for all registered routes
|
||||
// With fixTrailingSlash = true
|
||||
for _, route := range routes {
|
||||
out, found := tree.findCaseInsensitivePath(route, true)
|
||||
if !found {
|
||||
t.Errorf("Route '%s' not found!", route)
|
||||
} else if string(out) != route {
|
||||
t.Errorf("Wrong result for route '%s': %s", route, string(out))
|
||||
}
|
||||
}
|
||||
// With fixTrailingSlash = false
|
||||
for _, route := range routes {
|
||||
out, found := tree.findCaseInsensitivePath(route, false)
|
||||
if !found {
|
||||
t.Errorf("Route '%s' not found!", route)
|
||||
} else if string(out) != route {
|
||||
t.Errorf("Wrong result for route '%s': %s", route, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
in string
|
||||
out string
|
||||
found bool
|
||||
slash bool
|
||||
}{
|
||||
{"/HI", "/hi", true, false},
|
||||
{"/HI/", "/hi", true, true},
|
||||
{"/B", "/b/", true, true},
|
||||
{"/B/", "/b/", true, false},
|
||||
{"/abc", "/ABC/", true, true},
|
||||
{"/abc/", "/ABC/", true, false},
|
||||
{"/aBc", "/ABC/", true, true},
|
||||
{"/aBc/", "/ABC/", true, false},
|
||||
{"/abC", "/ABC/", true, true},
|
||||
{"/abC/", "/ABC/", true, false},
|
||||
{"/SEARCH/QUERY", "/search/QUERY", true, false},
|
||||
{"/SEARCH/QUERY/", "/search/QUERY", true, true},
|
||||
{"/CMD/TOOL/", "/cmd/TOOL/", true, false},
|
||||
{"/CMD/TOOL", "/cmd/TOOL/", true, true},
|
||||
{"/SRC/FILE/PATH", "/src/FILE/PATH", true, false},
|
||||
{"/x/Y", "/x/y", true, false},
|
||||
{"/x/Y/", "/x/y", true, true},
|
||||
{"/X/y", "/x/y", true, false},
|
||||
{"/X/y/", "/x/y", true, true},
|
||||
{"/X/Y", "/x/y", true, false},
|
||||
{"/X/Y/", "/x/y", true, true},
|
||||
{"/Y/", "/y/", true, false},
|
||||
{"/Y", "/y/", true, true},
|
||||
{"/Y/z", "/y/z", true, false},
|
||||
{"/Y/z/", "/y/z", true, true},
|
||||
{"/Y/Z", "/y/z", true, false},
|
||||
{"/Y/Z/", "/y/z", true, true},
|
||||
{"/y/Z", "/y/z", true, false},
|
||||
{"/y/Z/", "/y/z", true, true},
|
||||
{"/Aa", "/aa", true, false},
|
||||
{"/Aa/", "/aa", true, true},
|
||||
{"/AA", "/aa", true, false},
|
||||
{"/AA/", "/aa", true, true},
|
||||
{"/aA", "/aa", true, false},
|
||||
{"/aA/", "/aa", true, true},
|
||||
{"/A/", "/a/", true, false},
|
||||
{"/A", "/a/", true, true},
|
||||
{"/DOC", "/doc", true, false},
|
||||
{"/DOC/", "/doc", true, true},
|
||||
{"/NO", "", false, true},
|
||||
{"/DOC/GO", "", false, true},
|
||||
{"/π", "/Π", true, false},
|
||||
{"/π/", "/Π", true, true},
|
||||
{"/u/ÄPFÊL/", "/u/äpfêl/", true, false},
|
||||
{"/u/ÄPFÊL", "/u/äpfêl/", true, true},
|
||||
{"/u/ÖPFÊL/", "/u/öpfêl", true, true},
|
||||
{"/u/ÖPFÊL", "/u/öpfêl", true, false},
|
||||
{"/v/äpfêL/", "/v/Äpfêl/", true, false},
|
||||
{"/v/äpfêL", "/v/Äpfêl/", true, true},
|
||||
{"/v/öpfêL/", "/v/Öpfêl", true, true},
|
||||
{"/v/öpfêL", "/v/Öpfêl", true, false},
|
||||
{"/w/♬/", "/w/♬", true, true},
|
||||
{"/w/♭", "/w/♭/", true, true},
|
||||
{"/w/𠜎/", "/w/𠜎", true, true},
|
||||
{"/w/𠜏", "/w/𠜏/", true, true},
|
||||
}
|
||||
// With fixTrailingSlash = true
|
||||
for _, test := range tests {
|
||||
out, found := tree.findCaseInsensitivePath(test.in, true)
|
||||
if found != test.found || (found && (string(out) != test.out)) {
|
||||
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
|
||||
test.in, string(out), found, test.out, test.found)
|
||||
return
|
||||
}
|
||||
}
|
||||
// With fixTrailingSlash = false
|
||||
for _, test := range tests {
|
||||
out, found := tree.findCaseInsensitivePath(test.in, false)
|
||||
if test.slash {
|
||||
if found { // test needs a trailingSlash fix. It must not be found!
|
||||
t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, string(out))
|
||||
}
|
||||
} else {
|
||||
if found != test.found || (found && (string(out) != test.out)) {
|
||||
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
|
||||
test.in, string(out), found, test.out, test.found)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTreeInvalidNodeType(t *testing.T) {
|
||||
const panicMsg = "invalid node type"
|
||||
|
||||
tree := &node{}
|
||||
tree.addRoute("/", fakeHandler("/"))
|
||||
tree.addRoute("/:page", fakeHandler("/:page"))
|
||||
|
||||
// set invalid node type
|
||||
tree.children[0].nType = 42
|
||||
|
||||
// normal lookup
|
||||
recv := catchPanic(func() {
|
||||
tree.getValue("/test")
|
||||
})
|
||||
if rs, ok := recv.(string); !ok || rs != panicMsg {
|
||||
t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
|
||||
}
|
||||
|
||||
// case-insensitive lookup
|
||||
recv = catchPanic(func() {
|
||||
tree.findCaseInsensitivePath("/test", true)
|
||||
})
|
||||
if rs, ok := recv.(string); !ok || rs != panicMsg {
|
||||
t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
.DS_Store
|
||||
bin
|
||||
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.3
|
||||
- 1.4
|
||||
- 1.5
|
||||
- 1.6
|
||||
- tip
|
|
@ -0,0 +1,8 @@
|
|||
Copyright (c) 2012 Dave Grijalva
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
## Migration Guide from v2 -> v3
|
||||
|
||||
Version 3 adds several new, frequently requested features. To do so, it introduces a few breaking changes. We've worked to keep these as minimal as possible. This guide explains the breaking changes and how you can quickly update your code.
|
||||
|
||||
### `Token.Claims` is now an interface type
|
||||
|
||||
The most requested feature from the 2.0 verison of this library was the ability to provide a custom type to the JSON parser for claims. This was implemented by introducing a new interface, `Claims`, to replace `map[string]interface{}`. We also included two concrete implementations of `Claims`: `MapClaims` and `StandardClaims`.
|
||||
|
||||
`MapClaims` is an alias for `map[string]interface{}` with built in validation behavior. It is the default claims type when using `Parse`. The usage is unchanged except you must type cast the claims property.
|
||||
|
||||
The old example for parsing a token looked like this..
|
||||
|
||||
```go
|
||||
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
|
||||
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
is now directly mapped to...
|
||||
|
||||
```go
|
||||
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
`StandardClaims` is designed to be embedded in your custom type. You can supply a custom claims type with the new `ParseWithClaims` function. Here's an example of using a custom claims type.
|
||||
|
||||
```go
|
||||
type MyCustomClaims struct {
|
||||
User string
|
||||
*StandardClaims
|
||||
}
|
||||
|
||||
if token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, keyLookupFunc); err == nil {
|
||||
claims := token.Claims.(*MyCustomClaims)
|
||||
fmt.Printf("Token for user %v expires %v", claims.User, claims.StandardClaims.ExpiresAt)
|
||||
}
|
||||
```
|
||||
|
||||
### `ParseFromRequest` has been moved
|
||||
|
||||
To keep this library focused on the tokens without becoming overburdened with complex request processing logic, `ParseFromRequest` and its new companion `ParseFromRequestWithClaims` have been moved to a subpackage, `request`. The method signatues have also been augmented to receive a new argument: `Extractor`.
|
||||
|
||||
`Extractors` do the work of picking the token string out of a request. The interface is simple and composable.
|
||||
|
||||
This simple parsing example:
|
||||
|
||||
```go
|
||||
if token, err := jwt.ParseFromRequest(tokenString, req, keyLookupFunc); err == nil {
|
||||
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
is directly mapped to:
|
||||
|
||||
```go
|
||||
if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
There are several concrete `Extractor` types provided for your convenience:
|
||||
|
||||
* `HeaderExtractor` will search a list of headers until one contains content.
|
||||
* `ArgumentExtractor` will search a list of keys in request query and form arguments until one contains content.
|
||||
* `MultiExtractor` will try a list of `Extractors` in order until one returns content.
|
||||
* `AuthorizationHeaderExtractor` will look in the `Authorization` header for a `Bearer` token.
|
||||
* `OAuth2Extractor` searches the places an OAuth2 token would be specified (per the spec): `Authorization` header and `access_token` argument
|
||||
* `PostExtractionFilter` wraps an `Extractor`, allowing you to process the content before it's parsed. A simple example is stripping the `Bearer ` text from a header
|
||||
|
||||
|
||||
### RSA signing methods no longer accept `[]byte` keys
|
||||
|
||||
Due to a [critical vulnerability](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/), we've decided the convenience of accepting `[]byte` instead of `rsa.PublicKey` or `rsa.PrivateKey` isn't worth the risk of misuse.
|
||||
|
||||
To replace this behavior, we've added two helper methods: `ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error)` and `ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error)`. These are just simple helpers for unpacking PEM encoded PKCS1 and PKCS8 keys. If your keys are encoded any other way, all you need to do is convert them to the `crypto/rsa` package's types.
|
||||
|
||||
```go
|
||||
func keyLookupFunc(*Token) (interface{}, error) {
|
||||
// Don't forget to validate the alg is what you expect:
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Look up key
|
||||
key, err := lookupPublicKey(token.Header["kid"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unpack key from PEM encoded PKCS8
|
||||
return jwt.ParseRSAPublicKeyFromPEM(key)
|
||||
}
|
||||
```
|
|
@ -0,0 +1,85 @@
|
|||
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
|
||||
|
||||
[![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go)
|
||||
|
||||
**BREAKING CHANGES:*** Version 3.0.0 is here. It includes _a lot_ of changes including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
|
||||
|
||||
**NOTICE:** A vulnerability in JWT was [recently published](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). As this library doesn't force users to validate the `alg` is what they expected, it's possible your usage is effected. There will be an update soon to remedy this, and it will likey require backwards-incompatible changes to the API. In the short term, please make sure your implementation verifies the `alg` is what you expect.
|
||||
|
||||
|
||||
## What the heck is a JWT?
|
||||
|
||||
JWT.io has [a great introduction](https://jwt.io/introduction) to JSON Web Tokens.
|
||||
|
||||
In short, it's a signed JSON object that does something useful (for example, authentication). It's commonly used for `Bearer` tokens in Oauth 2. A token is made of three parts, separated by `.`'s. The first two parts are JSON objects, that have been [base64url](http://tools.ietf.org/html/rfc4648) encoded. The last part is the signature, encoded the same way.
|
||||
|
||||
The first part is called the header. It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used.
|
||||
|
||||
The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own.
|
||||
|
||||
## What's in the box?
|
||||
|
||||
This library supports the parsing and verification as well as the generation and signing of JWTs. Current supported signing algorithms are HMAC SHA, RSA, RSA-PSS, and ECDSA, though hooks are present for adding your own.
|
||||
|
||||
## Examples
|
||||
|
||||
See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage:
|
||||
|
||||
* [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-Parse--Hmac)
|
||||
* [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-New--Hmac)
|
||||
* [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples)
|
||||
|
||||
## Extensions
|
||||
|
||||
This library publishes all the necessary components for adding your own signing methods. Simply implement the `SigningMethod` interface and register a factory method using `RegisterSigningMethod`.
|
||||
|
||||
Here's an example of an extension that integrates with the Google App Engine signing tools: https://github.com/someone1/gcp-jwt-go
|
||||
|
||||
## Compliance
|
||||
|
||||
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
|
||||
|
||||
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
|
||||
|
||||
## Project Status & Versioning
|
||||
|
||||
This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few backwards-incompatible changes outside of major version updates (and only with good reason).
|
||||
|
||||
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
|
||||
|
||||
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v2`. It will do the right thing WRT semantic versioning.
|
||||
|
||||
## Usage Tips
|
||||
|
||||
### Signing vs Encryption
|
||||
|
||||
A token is simply a JSON object that is signed by its author. this tells you exactly two things about the data:
|
||||
|
||||
* The author of the token was in the possession of the signing secret
|
||||
* The data has not been modified since it was signed
|
||||
|
||||
It's important to know that JWT does not provide encryption, which means anyone who has access to the token can read its contents. If you need to protect (encrypt) the data, there is a companion spec, `JWE`, that provides this functionality. JWE is currently outside the scope of this library.
|
||||
|
||||
### Choosing a Signing Method
|
||||
|
||||
There are several signing methods available, and you should probably take the time to learn about the various options before choosing one. The principal design decision is most likely going to be symmetric vs asymmetric.
|
||||
|
||||
Symmetric signing methods, such as HSA, use only a single secret. This is probably the simplest signing method to use since any `[]byte` can be used as a valid secret. They are also slightly computationally faster to use, though this rarely is enough to matter. Symmetric signing methods work the best when both producers and consumers of tokens are trusted, or even the same system. Since the same secret is used to both sign and validate tokens, you can't easily distribute the key for validation.
|
||||
|
||||
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
|
||||
|
||||
### JWT and OAuth
|
||||
|
||||
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
|
||||
|
||||
Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
|
||||
|
||||
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
|
||||
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
|
||||
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
|
||||
|
||||
## More
|
||||
|
||||
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
|
||||
|
||||
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in to documentation.
|
|
@ -0,0 +1,105 @@
|
|||
## `jwt-go` Version History
|
||||
|
||||
#### 3.0.0
|
||||
|
||||
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code
|
||||
* Dropped support for `[]byte` keys when using RSA signing methods. This convenience feature could contribute to security vulnerabilities involving mismatched key types with signing methods.
|
||||
* `ParseFromRequest` has been moved to `request` subpackage and usage has changed
|
||||
* The `Claims` property on `Token` is now type `Claims` instead of `map[string]interface{}`. The default value is type `MapClaims`, which is an alias to `map[string]interface{}`. This makes it possible to use a custom type when decoding claims.
|
||||
* Other Additions and Changes
|
||||
* Added `Claims` interface type to allow users to decode the claims into a custom type
|
||||
* Added `ParseWithClaims`, which takes a third argument of type `Claims`. Use this function instead of `Parse` if you have a custom type you'd like to decode into.
|
||||
* Dramatically improved the functionality and flexibility of `ParseFromRequest`, which is now in the `request` subpackage
|
||||
* Added `ParseFromRequestWithClaims` which is the `FromRequest` equivalent of `ParseWithClaims`
|
||||
* Added new interface type `Extractor`, which is used for extracting JWT strings from http requests. Used with `ParseFromRequest` and `ParseFromRequestWithClaims`.
|
||||
* Added several new, more specific, validation errors to error type bitmask
|
||||
* Moved examples from README to executable example files
|
||||
* Signing method registry is now thread safe
|
||||
* Added new property to `ValidationError`, which contains the raw error returned by calls made by parse/verify (such as those returned by keyfunc or json parser)
|
||||
|
||||
#### 2.7.0
|
||||
|
||||
This will likely be the last backwards compatible release before 3.0.0, excluding essential bug fixes.
|
||||
|
||||
* Added new option `-show` to the `jwt` command that will just output the decoded token without verifying
|
||||
* Error text for expired tokens includes how long it's been expired
|
||||
* Fixed incorrect error returned from `ParseRSAPublicKeyFromPEM`
|
||||
* Documentation updates
|
||||
|
||||
#### 2.6.0
|
||||
|
||||
* Exposed inner error within ValidationError
|
||||
* Fixed validation errors when using UseJSONNumber flag
|
||||
* Added several unit tests
|
||||
|
||||
#### 2.5.0
|
||||
|
||||
* Added support for signing method none. You shouldn't use this. The API tries to make this clear.
|
||||
* Updated/fixed some documentation
|
||||
* Added more helpful error message when trying to parse tokens that begin with `BEARER `
|
||||
|
||||
#### 2.4.0
|
||||
|
||||
* Added new type, Parser, to allow for configuration of various parsing parameters
|
||||
* You can now specify a list of valid signing methods. Anything outside this set will be rejected.
|
||||
* You can now opt to use the `json.Number` type instead of `float64` when parsing token JSON
|
||||
* Added support for [Travis CI](https://travis-ci.org/dgrijalva/jwt-go)
|
||||
* Fixed some bugs with ECDSA parsing
|
||||
|
||||
#### 2.3.0
|
||||
|
||||
* Added support for ECDSA signing methods
|
||||
* Added support for RSA PSS signing methods (requires go v1.4)
|
||||
|
||||
#### 2.2.0
|
||||
|
||||
* Gracefully handle a `nil` `Keyfunc` being passed to `Parse`. Result will now be the parsed token and an error, instead of a panic.
|
||||
|
||||
#### 2.1.0
|
||||
|
||||
Backwards compatible API change that was missed in 2.0.0.
|
||||
|
||||
* The `SignedString` method on `Token` now takes `interface{}` instead of `[]byte`
|
||||
|
||||
#### 2.0.0
|
||||
|
||||
There were two major reasons for breaking backwards compatibility with this update. The first was a refactor required to expand the width of the RSA and HMAC-SHA signing implementations. There will likely be no required code changes to support this change.
|
||||
|
||||
The second update, while unfortunately requiring a small change in integration, is required to open up this library to other signing methods. Not all keys used for all signing methods have a single standard on-disk representation. Requiring `[]byte` as the type for all keys proved too limiting. Additionally, this implementation allows for pre-parsed tokens to be reused, which might matter in an application that parses a high volume of tokens with a small set of keys. Backwards compatibilty has been maintained for passing `[]byte` to the RSA signing methods, but they will also accept `*rsa.PublicKey` and `*rsa.PrivateKey`.
|
||||
|
||||
It is likely the only integration change required here will be to change `func(t *jwt.Token) ([]byte, error)` to `func(t *jwt.Token) (interface{}, error)` when calling `Parse`.
|
||||
|
||||
* **Compatibility Breaking Changes**
|
||||
* `SigningMethodHS256` is now `*SigningMethodHMAC` instead of `type struct`
|
||||
* `SigningMethodRS256` is now `*SigningMethodRSA` instead of `type struct`
|
||||
* `KeyFunc` now returns `interface{}` instead of `[]byte`
|
||||
* `SigningMethod.Sign` now takes `interface{}` instead of `[]byte` for the key
|
||||
* `SigningMethod.Verify` now takes `interface{}` instead of `[]byte` for the key
|
||||
* Renamed type `SigningMethodHS256` to `SigningMethodHMAC`. Specific sizes are now just instances of this type.
|
||||
* Added public package global `SigningMethodHS256`
|
||||
* Added public package global `SigningMethodHS384`
|
||||
* Added public package global `SigningMethodHS512`
|
||||
* Renamed type `SigningMethodRS256` to `SigningMethodRSA`. Specific sizes are now just instances of this type.
|
||||
* Added public package global `SigningMethodRS256`
|
||||
* Added public package global `SigningMethodRS384`
|
||||
* Added public package global `SigningMethodRS512`
|
||||
* Moved sample private key for HMAC tests from an inline value to a file on disk. Value is unchanged.
|
||||
* Refactored the RSA implementation to be easier to read
|
||||
* Exposed helper methods `ParseRSAPrivateKeyFromPEM` and `ParseRSAPublicKeyFromPEM`
|
||||
|
||||
#### 1.0.2
|
||||
|
||||
* Fixed bug in parsing public keys from certificates
|
||||
* Added more tests around the parsing of keys for RS256
|
||||
* Code refactoring in RS256 implementation. No functional changes
|
||||
|
||||
#### 1.0.1
|
||||
|
||||
* Fixed panic if RS256 signing method was passed an invalid key
|
||||
|
||||
#### 1.0.0
|
||||
|
||||
* First versioned release
|
||||
* API stabilized
|
||||
* Supports creating, signing, parsing, and validating JWT tokens
|
||||
* Supports RS256 and HS256 signing methods
|
|
@ -0,0 +1,134 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// For a type to be a Claims object, it must just have a Valid method that determines
|
||||
// if the token is invalid for any supported reason
|
||||
type Claims interface {
|
||||
Valid() error
|
||||
}
|
||||
|
||||
// Structured version of Claims Section, as referenced at
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1
|
||||
// See examples for how to use this with your own claim types
|
||||
type StandardClaims struct {
|
||||
Audience string `json:"aud,omitempty"`
|
||||
ExpiresAt int64 `json:"exp,omitempty"`
|
||||
Id string `json:"jti,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
}
|
||||
|
||||
// Validates time based claims "exp, iat, nbf".
|
||||
// There is no accounting for clock skew.
|
||||
// As well, if any of the above claims are not in the token, it will still
|
||||
// be considered a valid claim.
|
||||
func (c StandardClaims) Valid() error {
|
||||
vErr := new(ValidationError)
|
||||
now := TimeFunc().Unix()
|
||||
|
||||
// The claims below are optional, by default, so if they are set to the
|
||||
// default value in Go, let's not fail the verification for them.
|
||||
if c.VerifyExpiresAt(now, false) == false {
|
||||
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
|
||||
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
|
||||
vErr.Errors |= ValidationErrorExpired
|
||||
}
|
||||
|
||||
if c.VerifyIssuedAt(now, false) == false {
|
||||
vErr.Inner = fmt.Errorf("Token used before issued")
|
||||
vErr.Errors |= ValidationErrorIssuedAt
|
||||
}
|
||||
|
||||
if c.VerifyNotBefore(now, false) == false {
|
||||
vErr.Inner = fmt.Errorf("token is not valid yet")
|
||||
vErr.Errors |= ValidationErrorNotValidYet
|
||||
}
|
||||
|
||||
if vErr.valid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return vErr
|
||||
}
|
||||
|
||||
// Compares the aud claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {
|
||||
return verifyAud(c.Audience, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the exp claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
|
||||
return verifyExp(c.ExpiresAt, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the iat claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {
|
||||
return verifyIat(c.IssuedAt, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the iss claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool {
|
||||
return verifyIss(c.Issuer, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the nbf claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
|
||||
return verifyNbf(c.NotBefore, cmp, req)
|
||||
}
|
||||
|
||||
// ----- helpers
|
||||
|
||||
func verifyAud(aud string, cmp string, required bool) bool {
|
||||
if aud == "" {
|
||||
return !required
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func verifyExp(exp int64, now int64, required bool) bool {
|
||||
if exp == 0 {
|
||||
return !required
|
||||
}
|
||||
return now <= exp
|
||||
}
|
||||
|
||||
func verifyIat(iat int64, now int64, required bool) bool {
|
||||
if iat == 0 {
|
||||
return !required
|
||||
}
|
||||
return now >= iat
|
||||
}
|
||||
|
||||
func verifyIss(iss string, cmp string, required bool) bool {
|
||||
if iss == "" {
|
||||
return !required
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func verifyNbf(nbf int64, now int64, required bool) bool {
|
||||
if nbf == 0 {
|
||||
return !required
|
||||
}
|
||||
return now >= nbf
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
// Package jwt is a Go implementation of JSON Web Tokens: http://self-issued.info/docs/draft-jones-json-web-token.html
|
||||
//
|
||||
// See README.md for more info.
|
||||
package jwt
|
|
@ -0,0 +1,147 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
// Sadly this is missing from crypto/ecdsa compared to crypto/rsa
|
||||
ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
|
||||
)
|
||||
|
||||
// Implements the ECDSA family of signing methods signing methods
|
||||
type SigningMethodECDSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
KeySize int
|
||||
CurveBits int
|
||||
}
|
||||
|
||||
// Specific instances for EC256 and company
|
||||
var (
|
||||
SigningMethodES256 *SigningMethodECDSA
|
||||
SigningMethodES384 *SigningMethodECDSA
|
||||
SigningMethodES512 *SigningMethodECDSA
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ES256
|
||||
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
|
||||
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
|
||||
return SigningMethodES256
|
||||
})
|
||||
|
||||
// ES384
|
||||
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
|
||||
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
|
||||
return SigningMethodES384
|
||||
})
|
||||
|
||||
// ES512
|
||||
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
|
||||
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
|
||||
return SigningMethodES512
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SigningMethodECDSA) Alg() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
// Implements the Verify method from SigningMethod
|
||||
// For this verify method, key must be an ecdsa.PublicKey struct
|
||||
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
var sig []byte
|
||||
if sig, err = DecodeSegment(signature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the key
|
||||
var ecdsaKey *ecdsa.PublicKey
|
||||
switch k := key.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
ecdsaKey = k
|
||||
default:
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
if len(sig) != 2*m.KeySize {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
|
||||
r := big.NewInt(0).SetBytes(sig[:m.KeySize])
|
||||
s := big.NewInt(0).SetBytes(sig[m.KeySize:])
|
||||
|
||||
// Create hasher
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
}
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
|
||||
return nil
|
||||
} else {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the Sign method from SigningMethod
|
||||
// For this signing method, key must be an ecdsa.PrivateKey struct
|
||||
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
|
||||
// Get the key
|
||||
var ecdsaKey *ecdsa.PrivateKey
|
||||
switch k := key.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
ecdsaKey = k
|
||||
default:
|
||||
return "", ErrInvalidKeyType
|
||||
}
|
||||
|
||||
// Create the hasher
|
||||
if !m.Hash.Available() {
|
||||
return "", ErrHashUnavailable
|
||||
}
|
||||
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Sign the string and return r, s
|
||||
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
|
||||
curveBits := ecdsaKey.Curve.Params().BitSize
|
||||
|
||||
if m.CurveBits != curveBits {
|
||||
return "", ErrInvalidKey
|
||||
}
|
||||
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes += 1
|
||||
}
|
||||
|
||||
// We serialize the outpus (r and s) into big-endian byte arrays and pad
|
||||
// them with zeros on the left to make sure the sizes work out. Both arrays
|
||||
// must be keyBytes long, and the output must be 2*keyBytes long.
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
out := append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
return EncodeSegment(out), nil
|
||||
} else {
|
||||
return "", err
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue