fix(nix): update vendorHash and vendor dir for new deps

This commit is contained in:
Alexander
2026-04-10 18:25:19 +02:00
parent da59d8f83b
commit 9dc664a3ba
2533 changed files with 1304328 additions and 2 deletions
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+42
View File
@@ -0,0 +1,42 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// arrayCodec is the Codec used for bsoncore.Array values.
type arrayCodec struct{}
// EncodeValue is the ValueEncoder for bsoncore.Array values.
func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreArray {
return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
arr := val.Interface().(bsoncore.Array)
return copyArrayFromBytes(vw, arr)
}
// DecodeValue is the ValueDecoder for bsoncore.Array values.
func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tCoreArray {
return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
arr, err := appendArrayBytes(val.Interface().(bsoncore.Array), vr)
val.Set(reflect.ValueOf(arr))
return err
}
+203
View File
@@ -0,0 +1,203 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
"strings"
)
var emptyValue = reflect.Value{}
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
type ValueEncoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vee ValueEncoderError) Error() string {
typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
for _, t := range vee.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vee.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vee.Received.Kind().String()
if vee.Received.IsValid() {
received = vee.Received.Type().String()
}
return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
}
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
type ValueDecoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vde ValueDecoderError) Error() string {
typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
for _, t := range vde.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vde.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vde.Received.Kind().String()
if vde.Received.IsValid() {
received = vde.Received.Type().String()
}
if !vde.Received.CanSet() {
received = "unsettable " + received
}
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}
// EncodeContext is the contextual information required for a Codec to encode a
// value.
type EncodeContext struct {
*Registry
// minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits)
// that can represent the integer value.
minSize bool
errorOnInlineDuplicates bool
stringifyMapKeysWithFmt bool
nilMapAsEmpty bool
nilSliceAsEmpty bool
nilByteSliceAsEmpty bool
omitZeroStruct bool
omitEmpty bool
useJSONStructTags bool
}
// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
*Registry
// truncate, if true, instructs decoders to to truncate the fractional part of BSON "double"
// values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to
// BSON "decimal128" values.
truncate bool
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
// usage for this field is restricted to data typed as "any" or "map[string]any". If DocumentType is
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
// error.
defaultDocumentType reflect.Type
binaryAsSlice bool
// a false value results in a decoding error.
objectIDAsHexString bool
useJSONStructTags bool
useLocalTimeZone bool
zeroMaps bool
zeroStructs bool
}
// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
// to provide configuration information.
type ValueEncoder interface {
EncodeValue(EncodeContext, ValueWriter, reflect.Value) error
}
// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueEncoder.
type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error
// EncodeValue implements the ValueEncoder interface.
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
return fn(ec, vw, val)
}
// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a
// ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the
// EncodeContext.
type ValueDecoder interface {
DecodeValue(DecodeContext, ValueReader, reflect.Value) error
}
// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueDecoder.
type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error
// DecodeValue implements the ValueDecoder interface.
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
return fn(dc, vr, val)
}
// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
type typeDecoder interface {
decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
}
// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder.
type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
return fn(dc, vr, t)
}
// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
type decodeAdapter struct {
ValueDecoderFunc
typeDecoderFunc
}
var (
_ ValueDecoder = decodeAdapter{}
_ typeDecoder = decodeAdapter{}
)
func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if td, _ := vd.(typeDecoder); td != nil {
val, err := td.decodeType(dc, vr, t)
if err == nil && val.Type() != t {
// This conversion step is necessary for slices and maps. If a user declares variables like:
//
// type myBool bool
// var m map[string]myBool
//
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
// because we'll try to assign a value of type bool to one of type myBool.
val = val.Convert(t)
}
return val, err
}
val := reflect.New(t).Elem()
err := vd.DecodeValue(dc, vr, val)
return val, err
}
+128
View File
@@ -0,0 +1,128 @@
// Copyright (C) MongoDB, Inc. 2025-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"io"
)
// bufferedByteSrc implements the low-level byteSrc interface by reading
// directly from an in-memory byte slice. It provides efficient, zero-copy
// access for parsing BSON when the entire document is buffered in memory.
type bufferedByteSrc struct {
buf []byte // entire BSON document
offset int64 // Current read index into buf
}
var _ byteSrc = (*bufferedByteSrc)(nil)
// Read reads up to len(p) bytes from the in-memory buffer, advancing the offset
// by the number of bytes read.
func (b *bufferedByteSrc) readExact(p []byte) (int, error) {
if b.offset >= int64(len(b.buf)) {
return 0, io.EOF
}
n := copy(p, b.buf[b.offset:])
b.offset += int64(n)
return n, nil
}
// ReadByte returns the single byte at buf[offset] and advances offset by 1.
func (b *bufferedByteSrc) ReadByte() (byte, error) {
if b.offset >= int64(len(b.buf)) {
return 0, io.EOF
}
b.offset++
return b.buf[b.offset-1], nil
}
// peek returns buf[offset:offset+n] without advancing offset.
func (b *bufferedByteSrc) peek(n int) ([]byte, error) {
// Ensure we don't read past the end of the buffer.
if int64(n)+b.offset > int64(len(b.buf)) {
return b.buf[b.offset:], io.EOF
}
// Return the next n bytes without advancing the offset
return b.buf[b.offset : b.offset+int64(n)], nil
}
// discard advances offset by n bytes, returning the number of bytes discarded.
func (b *bufferedByteSrc) discard(n int) (int, error) {
// Ensure we don't read past the end of the buffer.
if int64(n)+b.offset > int64(len(b.buf)) {
// If we have exceeded the buffer length, discard only up to the end.
left := len(b.buf) - int(b.offset)
b.offset = int64(len(b.buf))
return left, io.EOF
}
// Advance the read position
b.offset += int64(n)
return n, nil
}
// readSlice scans buf[offset:] for the first occurrence of delim, returns
// buf[offset:idx+1], and advances offset past it; errors if delim not found.
func (b *bufferedByteSrc) readSlice(delim byte) ([]byte, error) {
// Ensure we don't read past the end of the buffer.
if b.offset >= int64(len(b.buf)) {
return nil, io.EOF
}
// Look for the delimiter in the remaining bytes
rem := b.buf[b.offset:]
idx := bytes.IndexByte(rem, delim)
if idx < 0 {
return nil, io.EOF
}
// Build the result slice up through the delimiter.
result := rem[:idx+1]
// Advance the offset past the delimiter.
b.offset += int64(idx + 1)
return result, nil
}
// pos returns the current read position in the buffer.
func (b *bufferedByteSrc) pos() int64 {
return b.offset
}
// regexLength will return the total byte length of a BSON regex value.
func (b *bufferedByteSrc) regexLength() (int32, error) {
rem := b.buf[b.offset:]
// Find end of the first C-string (pattern).
i := bytes.IndexByte(rem, 0x00)
if i < 0 {
return 0, io.EOF
}
// Find end of second C-string (options).
j := bytes.IndexByte(rem[i+1:], 0x00)
if j < 0 {
return 0, io.EOF
}
// Total length = first C-string length (pattern) + second C-string length
// (options) + 2 null terminators
return int32(i + j + 2), nil
}
func (*bufferedByteSrc) streamable() bool {
return false
}
func (b *bufferedByteSrc) reset() {
b.buf = nil
b.offset = 0
}
+97
View File
@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
)
// byteSliceCodec is the Codec used for []byte values.
type byteSliceCodec struct {
// encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null.
encodeNilAsEmpty bool
}
// Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be
// used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &byteSliceCodec{}
// EncodeValue is the ValueEncoder for []byte.
func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tByteSlice {
return emptyValue, ValueDecoderError{
Name: "ByteSliceDecodeValue",
Types: []reflect.Type{tByteSlice},
Received: reflect.Zero(t),
}
}
var data []byte
var err error
switch vrType := vr.Type(); vrType {
case TypeString:
str, err := vr.ReadString()
if err != nil {
return emptyValue, err
}
data = []byte(str)
case TypeSymbol:
sym, err := vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
data = []byte(sym)
case TypeBinary:
var subtype byte
data, subtype, err = vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
}
case TypeNull:
err = vr.ReadNull()
case TypeUndefined:
err = vr.ReadUndefined()
default:
return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType)
}
if err != nil {
return emptyValue, err
}
return reflect.ValueOf(data), nil
}
// DecodeValue is the ValueDecoder for []byte.
func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tByteSlice {
return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
elem, err := bsc.decodeType(dc, vr, tByteSlice)
if err != nil {
return err
}
val.Set(elem)
return nil
}
+168
View File
@@ -0,0 +1,168 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"sync"
"sync/atomic"
)
// Runtime check that the kind encoder and decoder caches can store any valid
// reflect.Kind constant.
func init() {
if s := reflect.Kind(len(kindEncoderCache{}.entries)).String(); s != "kind27" {
panic("The capacity of kindEncoderCache is too small.\n" +
"This is due to a new type being added to reflect.Kind.")
}
}
// statically assert array size
var (
_ = (kindEncoderCache{}).entries[reflect.UnsafePointer]
_ = (kindDecoderCache{}).entries[reflect.UnsafePointer]
)
type typeEncoderCache struct {
cache sync.Map // map[reflect.Type]ValueEncoder
}
func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) {
c.cache.Store(rt, enc)
}
func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) {
if v, _ := c.cache.Load(rt); v != nil {
return v.(ValueEncoder), true
}
return nil, false
}
func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder {
if v, loaded := c.cache.LoadOrStore(rt, enc); loaded {
enc = v.(ValueEncoder)
}
return enc
}
func (c *typeEncoderCache) Clone() *typeEncoderCache {
cc := new(typeEncoderCache)
c.cache.Range(func(k, v any) bool {
if k != nil && v != nil {
cc.cache.Store(k, v)
}
return true
})
return cc
}
type typeDecoderCache struct {
cache sync.Map // map[reflect.Type]ValueDecoder
}
func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) {
c.cache.Store(rt, dec)
}
func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) {
if v, _ := c.cache.Load(rt); v != nil {
return v.(ValueDecoder), true
}
return nil, false
}
func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder {
if v, loaded := c.cache.LoadOrStore(rt, dec); loaded {
dec = v.(ValueDecoder)
}
return dec
}
func (c *typeDecoderCache) Clone() *typeDecoderCache {
cc := new(typeDecoderCache)
c.cache.Range(func(k, v any) bool {
if k != nil && v != nil {
cc.cache.Store(k, v)
}
return true
})
return cc
}
// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueEncoder interface).
type kindEncoderCacheEntry struct {
enc ValueEncoder
}
type kindEncoderCache struct {
entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry
}
func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) {
if enc != nil && rt < reflect.Kind(len(c.entries)) {
c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc})
}
}
func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) {
if rt < reflect.Kind(len(c.entries)) {
if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok {
return ent.enc, ent.enc != nil
}
}
return nil, false
}
func (c *kindEncoderCache) Clone() *kindEncoderCache {
cc := new(kindEncoderCache)
for i, v := range c.entries {
if val := v.Load(); val != nil {
cc.entries[i].Store(val)
}
}
return cc
}
// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueDecoder interface).
type kindDecoderCacheEntry struct {
dec ValueDecoder
}
type kindDecoderCache struct {
entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry
}
func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) {
if rt < reflect.Kind(len(c.entries)) {
c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec})
}
}
func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) {
if rt < reflect.Kind(len(c.entries)) {
if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok {
return ent.dec, ent.dec != nil
}
}
return nil, false
}
func (c *kindDecoderCache) Clone() *kindDecoderCache {
cc := new(kindDecoderCache)
for i, v := range c.entries {
if val := v.Load(); val != nil {
cc.entries[i].Store(val)
}
}
return cc
}
+61
View File
@@ -0,0 +1,61 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
type condAddrEncoder struct {
canAddrEnc ValueEncoder
elseEnc ValueEncoder
}
var _ ValueEncoder = &condAddrEncoder{}
// newCondAddrEncoder returns an condAddrEncoder.
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return &encoder
}
// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if val.CanAddr() {
return cae.canAddrEnc.EncodeValue(ec, vw, val)
}
if cae.elseEnc != nil {
return cae.elseEnc.EncodeValue(ec, vw, val)
}
return errNoEncoder{Type: val.Type()}
}
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
type condAddrDecoder struct {
canAddrDec ValueDecoder
elseDec ValueDecoder
}
var _ ValueDecoder = &condAddrDecoder{}
// newCondAddrDecoder returns an CondAddrDecoder.
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
return &decoder
}
// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if val.CanAddr() {
return cad.canAddrDec.DecodeValue(dc, vr, val)
}
if cad.elseDec != nil {
return cad.elseDec.DecodeValue(dc, vr, val)
}
return errNoDecoder{Type: val.Type()}
}
+433
View File
@@ -0,0 +1,433 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// copyDocument handles copying one document from the src to the dst.
func copyDocument(dst ValueWriter, src ValueReader) error {
dr, err := src.ReadDocument()
if err != nil {
return err
}
dw, err := dst.WriteDocument()
if err != nil {
return err
}
return copyDocumentCore(dw, dr)
}
// copyArrayFromBytes copies the values from a BSON array represented as a
// []byte to a ValueWriter.
func copyArrayFromBytes(dst ValueWriter, src []byte) error {
aw, err := dst.WriteArray()
if err != nil {
return err
}
err = copyBytesToArrayWriter(aw, src)
if err != nil {
return err
}
return aw.WriteArrayEnd()
}
// copyDocumentFromBytes copies the values from a BSON document represented as a
// []byte to a ValueWriter.
func copyDocumentFromBytes(dst ValueWriter, src []byte) error {
dw, err := dst.WriteDocument()
if err != nil {
return err
}
err = copyBytesToDocumentWriter(dw, src)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
type writeElementFn func(key string) (ValueWriter, error)
// copyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an
// ArrayWriter.
func copyBytesToArrayWriter(dst ArrayWriter, src []byte) error {
wef := func(_ string) (ValueWriter, error) {
return dst.WriteArrayElement()
}
return copyBytesToValueWriter(src, wef)
}
// copyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
// DocumentWriter.
func copyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
wef := func(key string) (ValueWriter, error) {
return dst.WriteDocumentElement(key)
}
return copyBytesToValueWriter(src, wef)
}
func copyBytesToValueWriter(src []byte, wef writeElementFn) error {
// TODO(skriptble): Create errors types here. Anything that is a tag should be a property.
length, rem, ok := bsoncore.ReadLength(src)
if !ok {
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
}
if len(src) < int(length) {
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length)
}
rem = rem[:length-4]
var t bsoncore.Type
var key string
var val bsoncore.Value
for {
t, rem, ok = bsoncore.ReadType(rem)
if !ok {
return io.EOF
}
if t == bsoncore.Type(0) {
if len(rem) != 0 {
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
}
break
}
key, rem, ok = bsoncore.ReadKey(rem)
if !ok {
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
}
// write as either array element or document element using writeElementFn
vw, err := wef(key)
if err != nil {
return err
}
val, rem, ok = bsoncore.ReadValue(rem, t)
if !ok {
return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
}
err = copyValueFromBytes(vw, Type(t), val.Data)
if err != nil {
return err
}
}
return nil
}
// copyDocumentToBytes copies an entire document from the ValueReader and
// returns it as bytes.
func copyDocumentToBytes(src ValueReader) ([]byte, error) {
return appendDocumentBytes(nil, src)
}
// appendDocumentBytes functions the same as CopyDocumentToBytes, but will
// append the result to dst.
func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(bytesReader); ok {
_, dst, err := br.readValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer putValueWriter(vw)
vw.reset(dst)
err := copyDocument(vw, src)
dst = vw.buf
return dst, err
}
// appendArrayBytes copies an array from the ValueReader to dst.
func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(bytesReader); ok {
_, dst, err := br.readValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer putValueWriter(vw)
vw.reset(dst)
err := copyArray(vw, src)
dst = vw.buf
return dst, err
}
// copyValueFromBytes will write the value represtend by t and src to dst.
func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error {
if wvb, ok := dst.(bytesWriter); ok {
return wvb.writeValueBytes(t, src)
}
vr := newBufferedDocumentReader(src)
vr.advanceFrame()
vr.stack[vr.frame].mode = mElement
vr.stack[vr.frame].vType = t
return copyValue(dst, vr)
}
// copyValueToBytes copies a value from src and returns it as a Type and a
// []byte.
func copyValueToBytes(src ValueReader) (Type, []byte, error) {
if br, ok := src.(bytesReader); ok {
return br.readValueBytes(nil)
}
vw := vwPool.Get().(*valueWriter)
defer putValueWriter(vw)
vw.reset(nil)
vw.push(mElement)
err := copyValue(vw, src)
if err != nil {
return 0, nil, err
}
return Type(vw.buf[0]), vw.buf[2:], nil
}
// copyValue will copy a single value from src to dst.
func copyValue(dst ValueWriter, src ValueReader) error {
var err error
switch src.Type() {
case TypeDouble:
var f64 float64
f64, err = src.ReadDouble()
if err != nil {
break
}
err = dst.WriteDouble(f64)
case TypeString:
var str string
str, err = src.ReadString()
if err != nil {
return err
}
err = dst.WriteString(str)
case TypeEmbeddedDocument:
err = copyDocument(dst, src)
case TypeArray:
err = copyArray(dst, src)
case TypeBinary:
var data []byte
var subtype byte
data, subtype, err = src.ReadBinary()
if err != nil {
break
}
err = dst.WriteBinaryWithSubtype(data, subtype)
case TypeUndefined:
err = src.ReadUndefined()
if err != nil {
break
}
err = dst.WriteUndefined()
case TypeObjectID:
var oid ObjectID
oid, err = src.ReadObjectID()
if err != nil {
break
}
err = dst.WriteObjectID(oid)
case TypeBoolean:
var b bool
b, err = src.ReadBoolean()
if err != nil {
break
}
err = dst.WriteBoolean(b)
case TypeDateTime:
var dt int64
dt, err = src.ReadDateTime()
if err != nil {
break
}
err = dst.WriteDateTime(dt)
case TypeNull:
err = src.ReadNull()
if err != nil {
break
}
err = dst.WriteNull()
case TypeRegex:
var pattern, options string
pattern, options, err = src.ReadRegex()
if err != nil {
break
}
err = dst.WriteRegex(pattern, options)
case TypeDBPointer:
var ns string
var pointer ObjectID
ns, pointer, err = src.ReadDBPointer()
if err != nil {
break
}
err = dst.WriteDBPointer(ns, pointer)
case TypeJavaScript:
var js string
js, err = src.ReadJavascript()
if err != nil {
break
}
err = dst.WriteJavascript(js)
case TypeSymbol:
var symbol string
symbol, err = src.ReadSymbol()
if err != nil {
break
}
err = dst.WriteSymbol(symbol)
case TypeCodeWithScope:
var code string
var srcScope DocumentReader
code, srcScope, err = src.ReadCodeWithScope()
if err != nil {
break
}
var dstScope DocumentWriter
dstScope, err = dst.WriteCodeWithScope(code)
if err != nil {
break
}
err = copyDocumentCore(dstScope, srcScope)
case TypeInt32:
var i32 int32
i32, err = src.ReadInt32()
if err != nil {
break
}
err = dst.WriteInt32(i32)
case TypeTimestamp:
var t, i uint32
t, i, err = src.ReadTimestamp()
if err != nil {
break
}
err = dst.WriteTimestamp(t, i)
case TypeInt64:
var i64 int64
i64, err = src.ReadInt64()
if err != nil {
break
}
err = dst.WriteInt64(i64)
case TypeDecimal128:
var d128 Decimal128
d128, err = src.ReadDecimal128()
if err != nil {
break
}
err = dst.WriteDecimal128(d128)
case TypeMinKey:
err = src.ReadMinKey()
if err != nil {
break
}
err = dst.WriteMinKey()
case TypeMaxKey:
err = src.ReadMaxKey()
if err != nil {
break
}
err = dst.WriteMaxKey()
default:
err = fmt.Errorf("cannot copy unknown BSON type %s", src.Type())
}
return err
}
func copyArray(dst ValueWriter, src ValueReader) error {
ar, err := src.ReadArray()
if err != nil {
return err
}
aw, err := dst.WriteArray()
if err != nil {
return err
}
for {
vr, err := ar.ReadValue()
if errors.Is(err, ErrEOA) {
break
}
if err != nil {
return err
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = copyValue(vw, vr)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
return err
}
vw, err := dw.WriteDocumentElement(key)
if err != nil {
return err
}
err = copyValue(vw, vr)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// bytesReader is the interface used to read BSON bytes from a valueReader.
//
// The bytes of the value will be appended to dst.
type bytesReader interface {
readValueBytes(dst []byte) (Type, []byte, error)
}
// bytesWriter is the interface used to write BSON bytes to a valueWriter.
type bytesWriter interface {
writeValueBytes(t Type, b []byte) error
}
+341
View File
@@ -0,0 +1,341 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"encoding/json"
"errors"
"fmt"
"math/big"
"regexp"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/v2/internal/decimal128"
)
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
MaxDecimal128Exp = 6111
MinDecimal128Exp = -6176
)
// These errors are returned when an invalid value is parsed as a big.Int.
var (
ErrParseNaN = errors.New("cannot parse NaN as a *big.Int")
ErrParseInf = errors.New("cannot parse Infinity as a *big.Int")
ErrParseNegInf = errors.New("cannot parse -Infinity as a *big.Int")
)
// Decimal128 holds decimal128 BSON values.
type Decimal128 struct {
h, l uint64
}
// NewDecimal128 creates a Decimal128 using the provide high and low uint64s.
func NewDecimal128(h, l uint64) Decimal128 {
return Decimal128{h: h, l: l}
}
// GetBytes returns the underlying bytes of the BSON decimal value as two uint64 values. The first
// contains the most first 8 bytes of the value and the second contains the latter.
func (d Decimal128) GetBytes() (uint64, uint64) {
return d.h, d.l
}
// String returns a string representation of the decimal value.
func (d Decimal128) String() string {
return decimal128.String(d.h, d.l)
}
// BigInt returns significand as big.Int and exponent, bi * 10 ^ exp.
func (d Decimal128) BigInt() (*big.Int, int, error) {
high, low := d.GetBytes()
posSign := high>>63&1 == 0 // positive sign
switch high >> 58 & (1<<5 - 1) {
case 0x1F:
return nil, 0, ErrParseNaN
case 0x1E:
if posSign {
return nil, 0, ErrParseInf
}
return nil, 0, ErrParseNegInf
}
var exp int
if high>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
exp = int(high >> 47 & (1<<14 - 1))
// Spec says all of these values are out of range.
high, low = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
exp = int(high >> 49 & (1<<14 - 1))
high &= (1<<49 - 1)
}
exp += MinDecimal128Exp
// Would be handled by the logic below, but that's trivial and common.
if high == 0 && low == 0 && exp == 0 {
return new(big.Int), 0, nil
}
bi := big.NewInt(0)
const host32bit = ^uint(0)>>32 == 0
if host32bit {
bi.SetBits([]big.Word{big.Word(low), big.Word(low >> 32), big.Word(high), big.Word(high >> 32)})
} else {
bi.SetBits([]big.Word{big.Word(low), big.Word(high)})
}
if !posSign {
return bi.Neg(bi), exp, nil
}
return bi, exp, nil
}
// IsNaN returns whether d is NaN.
func (d Decimal128) IsNaN() bool {
return d.h>>58&(1<<5-1) == 0x1F
}
// IsInf returns:
//
// +1 d == Infinity
// 0 other case
// -1 d == -Infinity
func (d Decimal128) IsInf() int {
if d.h>>58&(1<<5-1) != 0x1E {
return 0
}
if d.h>>63&1 == 0 {
return 1
}
return -1
}
// IsZero returns true if d is the empty Decimal128.
func (d Decimal128) IsZero() bool {
return d.h == 0 && d.l == 0
}
// MarshalJSON returns Decimal128 as a string.
func (d Decimal128) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON creates a Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string
// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will
// be unchanged.
func (d *Decimal128) UnmarshalJSON(b []byte) error {
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
// enter the UnmarshalJSON hook.
if string(b) == "null" {
return nil
}
var res any
err := json.Unmarshal(b, &res)
if err != nil {
return err
}
str, ok := res.(string)
// Extended JSON
if !ok {
m, ok := res.(map[string]any)
if !ok {
return errors.New("not an extended JSON Decimal128: expected document")
}
d128, ok := m["$numberDecimal"]
if !ok {
return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
}
str, ok = d128.(string)
if !ok {
return errors.New("not an extended JSON Decimal128: expected decimal to be string")
}
}
*d, err = ParseDecimal128(str)
return err
}
var (
dNaN = Decimal128{0x1F << 58, 0}
dPosInf = Decimal128{0x1E << 58, 0}
dNegInf = Decimal128{0x3E << 58, 0}
)
func dErr(s string) (Decimal128, error) {
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
}
// match scientific notation number, example -10.15e-18
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)
// ParseDecimal128 takes the given string and attempts to parse it into a valid
// Decimal128 value.
func ParseDecimal128(s string) (Decimal128, error) {
if s == "" {
return dErr(s)
}
matches := normalNumber.FindStringSubmatch(s)
if len(matches) == 0 {
orig := s
neg := s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
}
if s == "NaN" || s == "nan" || strings.EqualFold(s, "nan") {
return dNaN, nil
}
if s == "Inf" || s == "inf" || strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
if neg {
return dNegInf, nil
}
return dPosInf, nil
}
return dErr(orig)
}
intPart := matches[1]
decPart := matches[2]
expPart := matches[3]
var err error
exp := 0
if expPart != "" {
exp, err = strconv.Atoi(expPart)
if err != nil {
return dErr(s)
}
}
if decPart != "" {
exp -= len(decPart)
}
if len(strings.Trim(intPart+decPart, "-0")) > 35 {
return dErr(s)
}
// Parse the significand (i.e. the non-exponent part) as a big.Int.
bi, ok := new(big.Int).SetString(intPart+decPart, 10)
if !ok {
return dErr(s)
}
d, ok := ParseDecimal128FromBigInt(bi, exp)
if !ok {
return dErr(s)
}
if bi.Sign() == 0 && s[0] == '-' {
d.h |= 1 << 63
}
return d, nil
}
var (
ten = big.NewInt(10)
zero = new(big.Int)
maxS, _ = new(big.Int).SetString("9999999999999999999999999999999999", 10)
)
// ParseDecimal128FromBigInt attempts to parse the given significand and exponent into a valid Decimal128 value.
func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) {
// copy
bi = new(big.Int).Set(bi)
q := new(big.Int)
r := new(big.Int)
// If the significand is zero, the logical value will always be zero, independent of the
// exponent. However, the loops for handling out-of-range exponent values below may be extremely
// slow for zero values because the significand never changes. Limit the exponent value to the
// supported range here to prevent entering the loops below.
if bi.Cmp(zero) == 0 {
if exp > MaxDecimal128Exp {
exp = MaxDecimal128Exp
}
if exp < MinDecimal128Exp {
exp = MinDecimal128Exp
}
}
for bigIntCmpAbs(bi, maxS) == 1 {
bi, _ = q.QuoRem(bi, ten, r)
if r.Cmp(zero) != 0 {
return Decimal128{}, false
}
exp++
if exp > MaxDecimal128Exp {
return Decimal128{}, false
}
}
for exp < MinDecimal128Exp {
// Subnormal.
bi, _ = q.QuoRem(bi, ten, r)
if r.Cmp(zero) != 0 {
return Decimal128{}, false
}
exp++
}
for exp > MaxDecimal128Exp {
// Clamped.
bi.Mul(bi, ten)
if bigIntCmpAbs(bi, maxS) == 1 {
return Decimal128{}, false
}
exp--
}
b := bi.Bytes()
var h, l uint64
for i := 0; i < len(b); i++ {
if i < len(b)-8 {
h = h<<8 | uint64(b[i])
continue
}
l = l<<8 | uint64(b[i])
}
h |= uint64(exp-MinDecimal128Exp) & uint64(1<<14-1) << 49
if bi.Sign() == -1 {
h |= 1 << 63
}
return Decimal128{h: h, l: l}, true
}
// bigIntCmpAbs computes big.Int.Cmp(absoluteValue(x), absoluteValue(y)).
func bigIntCmpAbs(x, y *big.Int) int {
xAbs := bigIntAbsValue(x)
yAbs := bigIntAbsValue(y)
return xAbs.Cmp(yAbs)
}
// bigIntAbsValue returns a big.Int containing the absolute value of b.
// If b is already a non-negative number, it is returned without any changes or copies.
func bigIntAbsValue(b *big.Int) *big.Int {
if b.Sign() >= 0 {
return b // already positive
}
return new(big.Int).Abs(b)
}
+143
View File
@@ -0,0 +1,143 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sync"
)
// ErrDecodeToNil is the error returned when trying to decode to a nil value
var ErrDecodeToNil = errors.New("cannot Decode to nil value")
// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var decPool = sync.Pool{
New: func() any {
return new(Decoder)
},
}
// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
// the source of BSON data.
type Decoder struct {
dc DecodeContext
vr ValueReader
}
// NewDecoder returns a new decoder that reads from vr.
func NewDecoder(vr ValueReader) *Decoder {
return &Decoder{
dc: DecodeContext{Registry: defaultRegistry},
vr: vr,
}
}
// Decode reads the next BSON document from the stream and decodes it into the
// value pointed to by val.
//
// See [Unmarshal] for details about BSON unmarshaling behavior.
func (d *Decoder) Decode(val any) error {
if unmarshaler, ok := val.(Unmarshaler); ok {
// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method.
buf, err := copyDocumentToBytes(d.vr)
if err != nil {
return err
}
return unmarshaler.UnmarshalBSON(buf)
}
rval := reflect.ValueOf(val)
switch rval.Kind() {
case reflect.Ptr:
if rval.IsNil() {
return ErrDecodeToNil
}
rval = rval.Elem()
case reflect.Map:
if rval.IsNil() {
return ErrDecodeToNil
}
default:
return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval)
}
decoder, err := d.dc.LookupDecoder(rval.Type())
if err != nil {
return err
}
return decoder.DecodeValue(d.dc, d.vr, rval)
}
// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading.
func (d *Decoder) Reset(vr ValueReader) {
d.vr = vr
}
// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *Registry) {
d.dc.Registry = r
}
// DefaultDocumentM causes the Decoder to always unmarshal documents into the bson.M type. This
// behavior is restricted to data typed as "any" or "map[string]any".
func (d *Decoder) DefaultDocumentM() {
d.dc.defaultDocumentType = reflect.TypeOf(M{})
}
// DefaultDocumentMap causes the Decoder to always unmarshal documents into the
// map[string]any type. This behavior is restricted to data typed as "any" or
// "map[string]any".
func (d *Decoder) DefaultDocumentMap() {
d.dc.defaultDocumentType = reflect.TypeOf(map[string]any{})
}
// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values
// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct
// field. The truncation logic does not apply to BSON "decimal128" values.
func (d *Decoder) AllowTruncatingDoubles() {
d.dc.truncate = true
}
// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or
// "Old" BSON binary subtype as a Go byte slice instead of a bson.Binary.
func (d *Decoder) BinaryAsSlice() {
d.dc.binaryAsSlice = true
}
// ObjectIDAsHexString causes the Decoder to decode object IDs to their hex representation.
func (d *Decoder) ObjectIDAsHexString() {
d.dc.objectIDAsHexString = true
}
// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (d *Decoder) UseJSONStructTags() {
d.dc.useJSONStructTags = true
}
// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead
// of the UTC timezone.
func (d *Decoder) UseLocalTimeZone() {
d.dc.useLocalTimeZone = true
}
// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value
// passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroMaps() {
d.dc.zeroMaps = true
}
// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination
// value passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroStructs() {
d.dc.zeroStructs = true
}
File diff suppressed because it is too large Load Diff
+518
View File
@@ -0,0 +1,518 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"errors"
"math"
"net/url"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
var bvwPool = sync.Pool{
New: func() any {
return new(valueWriter)
},
}
var errInvalidValue = errors.New("cannot encode invalid element")
var sliceWriterPool = sync.Pool{
New: func() any {
sw := make(sliceWriter, 0)
return &sw
},
}
func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error {
vw, err := dw.WriteDocumentElement(e.Key)
if err != nil {
return err
}
if e.Value == nil {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
if err != nil {
return err
}
return nil
}
// registerDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with
// the provided RegistryBuilder.
func registerDefaultEncoders(reg *Registry) {
mapEncoder := &mapCodec{}
uintCodec := &uintCodec{}
reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
reg.RegisterTypeEncoder(tTime, &timeCodec{})
reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))
reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Uint, uintCodec)
reg.RegisterKindEncoder(reflect.Uint8, uintCodec)
reg.RegisterKindEncoder(reflect.Uint16, uintCodec)
reg.RegisterKindEncoder(reflect.Uint32, uintCodec)
reg.RegisterKindEncoder(reflect.Uint64, uintCodec)
reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue))
reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue))
reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue))
reg.RegisterKindEncoder(reflect.Map, mapEncoder)
reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
reg.RegisterKindEncoder(reflect.String, &stringCodec{})
reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder))
reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{})
reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue))
reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue))
}
// booleanEncodeValue is the ValueEncoderFunc for bool types.
func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Bool {
return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
}
return vw.WriteBoolean(val.Bool())
}
func fitsIn32Bits(i int64) bool {
return math.MinInt32 <= i && i <= math.MaxInt32
}
// intEncodeValue is the ValueEncoderFunc for int types.
func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32:
return vw.WriteInt32(int32(val.Int()))
case reflect.Int:
i64 := val.Int()
if fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
case reflect.Int64:
i64 := val.Int()
if ec.minSize && fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
}
return ValueEncoderError{
Name: "IntEncodeValue",
Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
Received: val,
}
}
// floatEncodeValue is the ValueEncoderFunc for float types.
func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Float32, reflect.Float64:
return vw.WriteDouble(val.Float())
}
return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
}
// objectIDEncodeValue is the ValueEncoderFunc for ObjectID.
func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tOID {
return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
}
return vw.WriteObjectID(val.Interface().(ObjectID))
}
// decimal128EncodeValue is the ValueEncoderFunc for Decimal128.
func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDecimal {
return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
}
return vw.WriteDecimal128(val.Interface().(Decimal128))
}
// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number.
func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJSONNumber {
return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
}
jsnum := val.Interface().(json.Number)
// Attempt int first, then float64
if i64, err := jsnum.Int64(); err == nil {
return intEncodeValue(ec, vw, reflect.ValueOf(i64))
}
f64, err := jsnum.Float64()
if err != nil {
return err
}
return floatEncodeValue(ec, vw, reflect.ValueOf(f64))
}
// urlEncodeValue is the ValueEncoderFunc for url.URL.
func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tURL {
return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
}
u := val.Interface().(url.URL)
return vw.WriteString(u.String())
}
// arrayEncodeValue is the ValueEncoderFunc for array types.
func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Array {
return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
}
// If we have a []E we want to treat it as a document instead of as an array.
if val.Type().Elem() == tE {
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
e := val.Index(idx).Interface().(E)
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
var byteSlice []byte
for idx := 0; idx < val.Len(); idx++ {
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
}
return vw.WriteBinary(byteSlice)
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
return origEncoder, currVal, nil
}
currVal = currVal.Elem()
if !currVal.IsValid() {
return nil, currVal, errInvalidValue
}
currEncoder, err := ec.LookupEncoder(currVal.Type())
return currEncoder, currVal, err
}
// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement ValueMarshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
case val.Type().Implements(tValueMarshaler):
// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tValueMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}
m, ok := val.Interface().(ValueMarshaler)
if !ok {
return vw.WriteNull()
}
t, data, err := m.MarshalBSONValue()
if err != nil {
return err
}
return copyValueFromBytes(vw, Type(t), data)
}
// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Marshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
case val.Type().Implements(tMarshaler):
// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}
m, ok := val.Interface().(Marshaler)
if !ok {
return vw.WriteNull()
}
data, err := m.MarshalBSON()
if err != nil {
return err
}
return copyValueFromBytes(vw, TypeEmbeddedDocument, data)
}
// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type.
func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJavaScript {
return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
}
return vw.WriteJavascript(val.String())
}
// symbolEncodeValue is the ValueEncoderFunc for the Symbol type.
func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tSymbol {
return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
}
return vw.WriteSymbol(val.String())
}
// binaryEncodeValue is the ValueEncoderFunc for Binary.
func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tBinary {
return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
}
b := val.Interface().(Binary)
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// vectorEncodeValue is the ValueEncoderFunc for Vector.
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
t := val.Type()
if !val.IsValid() || t != tVector {
return ValueEncoderError{
Name: "VectorEncodeValue",
Types: []reflect.Type{tVector},
Received: val,
}
}
v := val.Interface().(Vector)
b := v.Binary()
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// undefinedEncodeValue is the ValueEncoderFunc for Undefined.
func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tUndefined {
return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
}
return vw.WriteUndefined()
}
// dateTimeEncodeValue is the ValueEncoderFunc for DateTime.
func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDateTime {
return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
}
return vw.WriteDateTime(val.Int())
}
// nullEncodeValue is the ValueEncoderFunc for Null.
func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tNull {
return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
}
return vw.WriteNull()
}
// regexEncodeValue is the ValueEncoderFunc for Regex.
func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRegex {
return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
}
regex := val.Interface().(Regex)
return vw.WriteRegex(regex.Pattern, regex.Options)
}
// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer.
func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDBPointer {
return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
}
dbp := val.Interface().(DBPointer)
return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
}
// timestampEncodeValue is the ValueEncoderFunc for Timestamp.
func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTimestamp {
return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
}
ts := val.Interface().(Timestamp)
return vw.WriteTimestamp(ts.T, ts.I)
}
// minKeyEncodeValue is the ValueEncoderFunc for MinKey.
func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMinKey {
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
}
return vw.WriteMinKey()
}
// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMaxKey {
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
}
return vw.WriteMaxKey()
}
// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreDocument {
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
}
cdoc := val.Interface().(bsoncore.Document)
return copyDocumentFromBytes(vw, cdoc)
}
// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCodeWithScope {
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
}
cws := val.Interface().(CodeWithScope)
dw, err := vw.WriteCodeWithScope(string(cws.Code))
if err != nil {
return err
}
sw := sliceWriterPool.Get().(*sliceWriter)
defer sliceWriterPool.Put(sw)
*sw = (*sw)[:0]
scopeVW := bvwPool.Get().(*valueWriter)
scopeVW.reset(scopeVW.buf[:0])
scopeVW.w = sw
defer bvwPool.Put(scopeVW)
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
if err != nil {
return err
}
err = copyBytesToDocumentWriter(dw, *sw)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
vt := val.Type()
for vt.Kind() == reflect.Ptr {
vt = vt.Elem()
}
return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
}
+141
View File
@@ -0,0 +1,141 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bson is a library for reading, writing, and manipulating BSON. BSON is a binary serialization
// format used to store documents and make remote procedure calls in MongoDB. For more information about
// the Go BSON library, including usage examples, check out the [Work with BSON] page in the Go Driver
// docs site. For more information about BSON, see https://bsonspec.org.
//
// # Native Go Types
//
// The [D] and [M] types defined in this package can be used to build representations of BSON using native Go types. D is a
// slice and M is a map. For more information about the use cases for these types, see the documentation on the type
// definitions.
//
// Note that a D should not be constructed with duplicate key names, as that can cause undefined server behavior.
//
// Example:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
//
// When decoding BSON to a D or M, the following type mappings apply when unmarshaling:
//
// 1. BSON int32 unmarshals to an int32.
// 2. BSON int64 unmarshals to an int64.
// 3. BSON double unmarshals to a float64.
// 4. BSON string unmarshals to a string.
// 5. BSON boolean unmarshals to a bool.
// 6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M).
// 7. BSON array unmarshals to a bson.A.
// 8. BSON ObjectId unmarshals to a bson.ObjectID.
// 9. BSON datetime unmarshals to a bson.DateTime.
// 10. BSON binary unmarshals to a bson.Binary.
// 11. BSON regular expression unmarshals to a bson.Regex.
// 12. BSON JavaScript unmarshals to a bson.JavaScript.
// 13. BSON code with scope unmarshals to a bson.CodeWithScope.
// 14. BSON timestamp unmarshals to an bson.Timestamp.
// 15. BSON 128-bit decimal unmarshals to an bson.Decimal128.
// 16. BSON min key unmarshals to an bson.MinKey.
// 17. BSON max key unmarshals to an bson.MaxKey.
// 18. BSON undefined unmarshals to a bson.Undefined.
// 19. BSON null unmarshals to nil.
// 20. BSON DBPointer unmarshals to a bson.DBPointer.
// 21. BSON symbol unmarshals to a bson.Symbol.
//
// The above mappings also apply when marshaling a D or M to BSON. Some other useful marshaling mappings are:
//
// 1. time.Time marshals to a BSON datetime.
// 2. int8, int16, and int32 marshal to a BSON int32.
// 3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64
// otherwise.
// 4. int64 marshals to BSON int64 (unless [Encoder.IntMinSize] is set).
// 5. uint8 and uint16 marshal to a BSON int32.
// 6. uint, uint32, and uint64 marshal to a BSON int64 (unless [Encoder.IntMinSize] is set).
// 7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshaling a BSON null or
// undefined value into a string will yield the empty string.).
//
// # Structs
//
// Structs can be marshaled/unmarshaled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended
// JSON, the following rules apply:
//
// 1. Only exported fields in structs will be marshaled or unmarshaled.
//
// 2. When marshaling a struct, each field will be lowercased to generate the key for the corresponding BSON element.
// For example, a struct field named "Foo" will generate key "foo". This can be overridden via a struct tag (e.g.
// `bson:"fooField"` to generate key "fooField" instead).
//
// 3. An embedded struct field is marshaled as a subdocument. The key will be the lowercased name of the field's type.
//
// 4. A pointer field is marshaled as the underlying type if the pointer is non-nil. If the pointer is nil, it is
// marshaled as a BSON null value.
//
// 5. When unmarshaling, a field of type any will follow the D/M type mappings listed above. BSON documents
// unmarshaled into an any field will be unmarshaled as a D.
//
// The encoding of each struct field can be customized by the "bson" struct tag. The "bson" tag gives the name of the
// field, followed by a comma-separated list of options. The name may be omitted in order to specify options without
// overriding the default field name. The following options can be used to configure behavior:
//
// 1. omitempty: If the "omitempty" struct tag is specified on a field, the field will not be marshaled if it is set to
// an "empty" value. Numbers, booleans, and strings are considered empty if their value is equal to the zero value for
// the type (i.e. 0 for numbers, false for booleans, and "" for strings). Slices, maps, and arrays are considered
// empty if they are of length zero. Interfaces and pointers are considered empty if their value is nil. By default,
// structs are only considered empty if the struct type implements [Zeroer] and the "IsZero"
// method returns true. Struct types that do not implement [Zeroer] are never considered empty and will be
// marshaled as embedded documents. NOTE: It is recommended that this tag be used for all slice and map fields.
//
// 2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of
// the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For
// other types, this tag is ignored.
//
// 3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles
// unmarshaled into that field will be truncated at the decimal point. For example, if 3.14 is unmarshaled into a
// field of type int, it will be unmarshaled as 3. If this tag is not specified, the decoder will throw an error if
// the value cannot be decoded without losing precision. For float64 or non-numeric types, this tag is ignored.
//
// 4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when
// marshaling and "un-flattened" when unmarshaling. This means that all of the fields in that struct/map will be
// pulled up one level and will become top-level fields rather than being fields in a nested document. For example,
// if a map field named "Map" with value map[string]any{"foo": "bar"} is inlined, the resulting document will
// be {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If
// there are duplicated fields in the resulting document when an inlined struct is marshaled, the inlined field will
// be overwritten. If there are duplicated fields in the resulting document when an inlined map is marshaled, an
// error will be returned. This tag can be used with fields that are pointers to structs. If an inlined pointer field
// is nil, it will not be marshaled. For fields that are not maps or structs, this tag is ignored.
//
// # Raw BSON
//
// The Raw family of types is used to validate and retrieve elements from a slice of bytes. This
// type is most useful when you want do lookups on BSON bytes without unmarshaling it into another
// type.
//
// Example:
//
// var raw bson.Raw = ... // bytes from somewhere
// err := raw.Validate()
// if err != nil { return err }
// val := raw.Lookup("foo")
// i32, ok := val.Int32OK()
// // do something with i32...
//
// # Custom Registry
//
// The Go BSON library uses a [Registry] to define encoding and decoding behavior for different data types.
// The default encoding and decoding behavior can be customized or extended by using a modified Registry.
// The custom registry system is composed of two parts:
//
// 1) [ValueEncoder] and [ValueDecoder] that handle encoding and decoding Go values to and from BSON
// representations.
//
// 2) A [Registry] that holds these ValueEncoders and ValueDecoders and provides methods for
// retrieving them.
//
// To use a custom Registry, use [Encoder.SetRegistry] or [Decoder.SetRegistry].
//
// [Work with BSON]: https://www.mongodb.com/docs/drivers/go/current/fundamentals/bson/
package bson
+127
View File
@@ -0,0 +1,127 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// emptyInterfaceCodec is the Codec used for any values.
type emptyInterfaceCodec struct {
// decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the
// "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary.
decodeBinaryAsSlice bool
}
// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it
// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &emptyInterfaceCodec{}
// EncodeValue is the ValueEncoderFunc for any.
func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) {
isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument
if isDocument {
if dc.defaultDocumentType != nil {
// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
// that type.
return dc.defaultDocumentType, nil
}
}
rtype, err := dc.LookupTypeMapEntry(valueType)
if err == nil {
return rtype, nil
}
if isDocument {
// For documents, fallback to looking up a type map entry for Type(0) or TypeEmbeddedDocument,
// depending on the original valueType.
var lookupType Type
switch valueType {
case Type(0):
lookupType = TypeEmbeddedDocument
case TypeEmbeddedDocument:
lookupType = Type(0)
}
rtype, err = dc.LookupTypeMapEntry(lookupType)
if err == nil {
return rtype, nil
}
// fallback to bson.D
return tD, nil
}
return nil, err
}
func (eic *emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tEmpty {
return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)}
}
rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
if err != nil {
switch vr.Type() {
case TypeNull:
return reflect.Zero(t), vr.ReadNull()
default:
return emptyValue, err
}
}
decoder, err := dc.LookupDecoder(rtype)
if err != nil {
return emptyValue, err
}
elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype)
if err != nil {
return emptyValue, err
}
if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary {
binElem := elem.Interface().(Binary)
if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld {
elem = reflect.ValueOf(binElem.Data)
}
}
return elem, nil
}
// DecodeValue is the ValueDecoderFunc for any.
func (eic *emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tEmpty {
return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
elem, err := eic.decodeType(dc, vr, val.Type())
if err != nil {
return err
}
val.Set(elem)
return nil
}
+130
View File
@@ -0,0 +1,130 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"sync"
)
// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Encoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var encPool = sync.Pool{
New: func() any {
return new(Encoder)
},
}
// An Encoder writes a serialization format to an output stream. It writes to a ValueWriter
// as the destination of BSON data.
type Encoder struct {
ec EncodeContext
vw ValueWriter
}
// NewEncoder returns a new encoder that writes to vw.
func NewEncoder(vw ValueWriter) *Encoder {
return &Encoder{
ec: EncodeContext{Registry: defaultRegistry},
vw: vw,
}
}
// Encode writes the BSON encoding of val to the stream.
//
// See [Marshal] for details about BSON marshaling behavior.
func (e *Encoder) Encode(val any) error {
if marshaler, ok := val.(Marshaler); ok {
// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse?
buf, err := marshaler.MarshalBSON()
if err != nil {
return err
}
return copyDocumentFromBytes(e.vw, buf)
}
encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val))
if err != nil {
return err
}
return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
}
// Reset will reset the state of the Encoder, using the same *EncodeContext used in
// the original construction but using vw.
func (e *Encoder) Reset(vw ValueWriter) {
e.vw = vw
}
// SetRegistry replaces the current registry of the Encoder with r.
func (e *Encoder) SetRegistry(r *Registry) {
e.ec.Registry = r
}
// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in
// the marshaled BSON when the "inline" struct tag option is set.
func (e *Encoder) ErrorOnInlineDuplicates() {
e.ec.errorOnInlineDuplicates = true
}
// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint,
// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can
// represent the integer value.
func (e *Encoder) IntMinSize() {
e.ec.minSize = true
}
// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name
// strings using fmt.Sprint instead of the default string conversion logic.
func (e *Encoder) StringifyMapKeysWithFmt() {
e.ec.stringifyMapKeysWithFmt = true
}
// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON
// null.
func (e *Encoder) NilMapAsEmpty() {
e.ec.nilMapAsEmpty = true
}
// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON
// null.
func (e *Encoder) NilSliceAsEmpty() {
e.ec.nilSliceAsEmpty = true
}
// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null.
func (e *Encoder) NilByteSliceAsEmpty() {
e.ec.nilByteSliceAsEmpty = true
}
// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported
// TODO struct fields once the logic is updated to also inspect private struct fields.
// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{})
// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set
// or the OmitEmpty() method is called.
//
// Note that the Encoder only examines exported struct fields when determining if a struct is the
// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty.
func (e *Encoder) OmitZeroStruct() {
e.ec.omitZeroStruct = true
}
// OmitEmpty causes the Encoder to omit empty values from the marshaled BSON as the "omitempty"
// struct tag option is set.
func (e *Encoder) OmitEmpty() {
e.ec.omitEmpty = true
}
// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (e *Encoder) UseJSONStructTags() {
e.ec.useJSONStructTags = true
}
+803
View File
@@ -0,0 +1,803 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
const maxNestingDepth = 200
// ErrInvalidJSON indicates the JSON input is invalid
var ErrInvalidJSON = errors.New("invalid JSON input")
type jsonParseState byte
const (
jpsStartState jsonParseState = iota
jpsSawBeginObject
jpsSawEndObject
jpsSawBeginArray
jpsSawEndArray
jpsSawColon
jpsSawComma
jpsSawKey
jpsSawValue
jpsDoneState
jpsInvalidState
)
type jsonParseMode byte
const (
jpmInvalidMode jsonParseMode = iota
jpmObjectMode
jpmArrayMode
)
type extJSONValue struct {
t Type
v any
}
type extJSONObject struct {
keys []string
values []*extJSONValue
}
type extJSONParser struct {
js *jsonScanner
s jsonParseState
m []jsonParseMode
k string
v *extJSONValue
err error
canonicalOnly bool
depth int
maxDepth int
emptyObject bool
relaxedUUID bool
}
// newExtJSONParser returns a new extended JSON parser, ready to to begin
// parsing from the first character of the argued json input. It will not
// perform any read-ahead and will therefore not report any errors about
// malformed JSON at this point.
func newExtJSONParser(r io.Reader, canonicalOnly bool) *extJSONParser {
return &extJSONParser{
js: &jsonScanner{r: r},
s: jpsStartState,
m: []jsonParseMode{},
canonicalOnly: canonicalOnly,
maxDepth: maxNestingDepth,
}
}
// peekType examines the next value and returns its BSON Type
func (ejp *extJSONParser) peekType() (Type, error) {
var t Type
var err error
initialState := ejp.s
ejp.advanceState()
switch ejp.s {
case jpsSawValue:
t = ejp.v.t
case jpsSawBeginArray:
t = TypeArray
case jpsInvalidState:
err = ejp.err
case jpsSawComma:
// in array mode, seeing a comma means we need to progress again to actually observe a type
if ejp.peekMode() == jpmArrayMode {
return ejp.peekType()
}
case jpsSawEndArray:
// this would only be a valid state if we were in array mode, so return end-of-array error
err = ErrEOA
case jpsSawBeginObject:
// peek key to determine type
ejp.advanceState()
switch ejp.s {
case jpsSawEndObject: // empty embedded document
t = TypeEmbeddedDocument
ejp.emptyObject = true
case jpsInvalidState:
err = ejp.err
case jpsSawKey:
if initialState == jpsStartState {
return TypeEmbeddedDocument, nil
}
t = wrapperKeyBSONType(ejp.k)
// if $uuid is encountered, parse as binary subtype 4
if ejp.k == "$uuid" {
ejp.relaxedUUID = true
t = TypeBinary
}
switch t {
case TypeJavaScript:
// just saw $code, need to check for $scope at same level
_, err = ejp.readValue(TypeJavaScript)
if err != nil {
break
}
switch ejp.s {
case jpsSawEndObject: // type is TypeJavaScript
case jpsSawComma:
ejp.advanceState()
if ejp.s == jpsSawKey && ejp.k == "$scope" {
t = TypeCodeWithScope
} else {
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
}
case jpsInvalidState:
err = ejp.err
default:
err = ErrInvalidJSON
}
case TypeCodeWithScope:
err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
}
}
}
return t, err
}
// readKey parses the next key and its type and returns them
func (ejp *extJSONParser) readKey() (string, Type, error) {
if ejp.emptyObject {
ejp.emptyObject = false
return "", 0, ErrEOD
}
// advance to key (or return with error)
switch ejp.s {
case jpsStartState:
ejp.advanceState()
if ejp.s == jpsSawBeginObject {
ejp.advanceState()
}
case jpsSawBeginObject:
ejp.advanceState()
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject, jpsSawComma:
ejp.advanceState()
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsDoneState:
return "", 0, io.EOF
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, ErrInvalidJSON
}
case jpsSawKey: // do nothing (key was peeked before)
default:
return "", 0, invalidRequestError("key")
}
// read key
var key string
switch ejp.s {
case jpsSawKey:
key = ejp.k
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, invalidRequestError("key")
}
// check for colon
ejp.advanceState()
if err := ensureColon(ejp.s, key); err != nil {
return "", 0, err
}
// peek at the value to determine type
t, err := ejp.peekType()
if err != nil {
return "", 0, err
}
return key, t, nil
}
// readValue returns the value corresponding to the Type returned by peekType
func (ejp *extJSONParser) readValue(t Type) (*extJSONValue, error) {
if ejp.s == jpsInvalidState {
return nil, ejp.err
}
var v *extJSONValue
switch t {
case TypeNull, TypeBoolean, TypeString:
if ejp.s != jpsSawValue {
return nil, invalidRequestError(t.String())
}
v = ejp.v
case TypeInt32, TypeInt64, TypeDouble:
// relaxed version allows these to be literal number values
if ejp.s == jpsSawValue {
v = ejp.v
break
}
fallthrough
case TypeDecimal128, TypeSymbol, TypeObjectID, TypeMinKey, TypeMaxKey, TypeUndefined:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("} after value", t)
}
default:
return nil, invalidRequestError(t.String())
}
case TypeBinary, TypeRegex, TypeTimestamp, TypeDBPointer:
if ejp.s != jpsSawKey {
return nil, invalidRequestError(t.String())
}
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
if t == TypeBinary && ejp.s == jpsSawValue {
// convert relaxed $uuid format
if ejp.relaxedUUID {
defer func() { ejp.relaxedUUID = false }()
uuid, err := ejp.v.parseSymbol()
if err != nil {
return nil, err
}
// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
// in the 8th, 13th, 18th, and 23rd characters.
//
// See https://tools.ietf.org/html/rfc4122#section-3
valid := len(uuid) == 36 &&
string(uuid[8]) == "-" &&
string(uuid[13]) == "-" &&
string(uuid[18]) == "-" &&
string(uuid[23]) == "-"
if !valid {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// remove hyphens
uuidNoHyphens := strings.ReplaceAll(uuid, "-", "")
if len(uuidNoHyphens) != 32 {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// convert hex to bytes
bytes, err := hex.DecodeString(uuidNoHyphens)
if err != nil {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("$uuid and value and then }", TypeBinary)
}
base64 := &extJSONValue{
t: TypeString,
v: base64.StdEncoding.EncodeToString(bytes),
}
subType := &extJSONValue{
t: TypeString,
v: "04",
}
v = &extJSONValue{
t: TypeEmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// convert legacy $binary format
base64 := ejp.v
ejp.advanceState()
if ejp.s != jpsSawComma {
return nil, invalidJSONErrorForType(",", TypeBinary)
}
ejp.advanceState()
key, t, err := ejp.readKey()
if err != nil {
return nil, err
}
if key != "$type" {
return nil, invalidJSONErrorForType("$type", TypeBinary)
}
subType, err := ejp.readValue(t)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", TypeBinary)
}
v = &extJSONValue{
t: TypeEmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// read KV pairs
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONErrorForType("{", t)
}
keys, vals, err := ejp.readObject(2, true)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
}
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case TypeDateTime:
switch ejp.s {
case jpsSawValue:
v = ejp.v
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject:
keys, vals, err := ejp.readObject(1, true)
if err != nil {
return nil, err
}
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case jpsSawValue:
if ejp.canonicalOnly {
return nil, invalidJSONError("{")
}
v = ejp.v
default:
if ejp.canonicalOnly {
return nil, invalidJSONErrorForType("object", t)
}
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("value and then }", t)
}
default:
return nil, invalidRequestError(t.String())
}
case TypeJavaScript:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object or comma and just return
ejp.advanceState()
case jpsSawEndObject:
v = ejp.v
default:
return nil, invalidRequestError(t.String())
}
case TypeCodeWithScope:
if ejp.s == jpsSawKey && ejp.k == "$scope" {
v = ejp.v // this is the $code string from earlier
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONError("$scope to be embedded document")
}
} else {
return nil, invalidRequestError(t.String())
}
case TypeEmbeddedDocument, TypeArray:
return nil, invalidRequestError(t.String())
}
return v, nil
}
// readObject is a utility method for reading full objects of known (or expected) size
// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
keys := make([]string, numKeys)
vals := make([]*extJSONValue, numKeys)
if !started {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, nil, invalidJSONError("{")
}
}
for i := 0; i < numKeys; i++ {
key, t, err := ejp.readKey()
if err != nil {
return nil, nil, err
}
switch ejp.s {
case jpsSawKey:
v, err := ejp.readValue(t)
if err != nil {
return nil, nil, err
}
keys[i] = key
vals[i] = v
case jpsSawValue:
keys[i] = key
vals[i] = ejp.v
default:
return nil, nil, invalidJSONError("value")
}
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, nil, invalidJSONError("}")
}
return keys, vals, nil
}
// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type
func (ejp *extJSONParser) advanceState() {
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
return
}
jt, err := ejp.js.nextToken()
if err != nil {
ejp.err = err
ejp.s = jpsInvalidState
return
}
valid := ejp.validateToken(jt.t)
if !valid {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
return
}
switch jt.t {
case jttBeginObject:
ejp.s = jpsSawBeginObject
ejp.pushMode(jpmObjectMode)
ejp.depth++
if ejp.depth > ejp.maxDepth {
ejp.err = nestingDepthError(jt.p, ejp.depth)
ejp.s = jpsInvalidState
}
case jttEndObject:
ejp.s = jpsSawEndObject
ejp.depth--
if ejp.popMode() != jpmObjectMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttBeginArray:
ejp.s = jpsSawBeginArray
ejp.pushMode(jpmArrayMode)
case jttEndArray:
ejp.s = jpsSawEndArray
if ejp.popMode() != jpmArrayMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttColon:
ejp.s = jpsSawColon
case jttComma:
ejp.s = jpsSawComma
case jttEOF:
ejp.s = jpsDoneState
if len(ejp.m) != 0 {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttString:
switch ejp.s {
case jpsSawComma:
if ejp.peekMode() == jpmArrayMode {
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
return
}
fallthrough
case jpsSawBeginObject:
ejp.s = jpsSawKey
ejp.k = jt.v.(string)
return
}
fallthrough
default:
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
}
}
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
jpsStartState: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
jttEOF: true,
},
jpsSawBeginObject: {
jttEndObject: true,
jttString: true,
},
jpsSawEndObject: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawBeginArray: {
jttBeginObject: true,
jttBeginArray: true,
jttEndArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawEndArray: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawColon: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawComma: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawKey: {
jttColon: true,
},
jpsSawValue: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsDoneState: {},
jpsInvalidState: {},
}
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
switch ejp.s {
case jpsSawEndObject:
// if we are at depth zero and the next token is a '{',
// we can consider it valid only if we are not in array mode.
if jtt == jttBeginObject && ejp.depth == 0 {
return ejp.peekMode() != jpmArrayMode
}
case jpsSawComma:
switch ejp.peekMode() {
// the only valid next token after a comma inside a document is a string (a key)
case jpmObjectMode:
return jtt == jttString
case jpmInvalidMode:
return false
}
}
_, ok := jpsValidTransitionTokens[ejp.s][jtt]
return ok
}
// ensureExtValueType returns true if the current value has the expected
// value type for single-key extended JSON types. For example,
// {"$numberInt": v} v must be TypeString
func (ejp *extJSONParser) ensureExtValueType(t Type) bool {
switch t {
case TypeMinKey, TypeMaxKey:
return ejp.v.t == TypeInt32
case TypeUndefined:
return ejp.v.t == TypeBoolean
case TypeInt32, TypeInt64, TypeDouble, TypeDecimal128, TypeSymbol, TypeObjectID:
return ejp.v.t == TypeString
default:
return false
}
}
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
ejp.m = append(ejp.m, m)
}
func (ejp *extJSONParser) popMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
m := ejp.m[l-1]
ejp.m = ejp.m[:l-1]
return m
}
func (ejp *extJSONParser) peekMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
return ejp.m[l-1]
}
func extendJSONToken(jt *jsonToken) *extJSONValue {
var t Type
switch jt.t {
case jttInt32:
t = TypeInt32
case jttInt64:
t = TypeInt64
case jttDouble:
t = TypeDouble
case jttString:
t = TypeString
case jttBool:
t = TypeBoolean
case jttNull:
t = TypeNull
default:
return nil
}
return &extJSONValue{t: t, v: jt.v}
}
func ensureColon(s jsonParseState, key string) error {
if s != jpsSawColon {
return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
}
return nil
}
func invalidRequestError(s string) error {
return fmt.Errorf("invalid request to read %s", s)
}
func invalidJSONError(expected string) error {
return fmt.Errorf("invalid JSON input; expected %s", expected)
}
func invalidJSONErrorForType(expected string, t Type) error {
return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
}
func unexpectedTokenError(jt *jsonToken) error {
switch jt.t {
case jttInt32, jttInt64, jttDouble:
return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
case jttString:
return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
case jttBool:
return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
case jttNull:
return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
case jttEOF:
return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
default:
return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
}
}
func nestingDepthError(p, depth int) error {
return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
}
+604
View File
@@ -0,0 +1,604 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
)
type ejvrState struct {
mode mode
vType Type
depth int
}
// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
p *extJSONParser
stack []ejvrState
frame int
}
// NewExtJSONValueReader returns a ValueReader that reads Extended JSON values
// from r. If canonicalOnly is true, reading values from the ValueReader returns
// an error if the Extended JSON was not marshaled in canonical mode.
func NewExtJSONValueReader(r io.Reader, canonicalOnly bool) (ValueReader, error) {
return newExtJSONValueReader(r, canonicalOnly)
}
func newExtJSONValueReader(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
ejvr := new(extJSONValueReader)
return ejvr.reset(r, canonicalOnly)
}
func (ejvr *extJSONValueReader) reset(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
p := newExtJSONParser(r, canonicalOnly)
typ, err := p.peekType()
if err != nil {
return nil, ErrInvalidJSON
}
var m mode
switch typ {
case TypeEmbeddedDocument:
m = mTopLevel
case TypeArray:
m = mArray
default:
m = mValue
}
stack := make([]ejvrState, 1, 5)
stack[0] = ejvrState{
mode: m,
vType: typ,
}
return &extJSONValueReader{
p: p,
stack: stack,
}, nil
}
func (ejvr *extJSONValueReader) advanceFrame() {
if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
length := len(ejvr.stack)
if length+1 >= cap(ejvr.stack) {
// double it
buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
copy(buf, ejvr.stack)
ejvr.stack = buf
}
ejvr.stack = ejvr.stack[:length+1]
}
ejvr.frame++
// Clean the stack
ejvr.stack[ejvr.frame].mode = 0
ejvr.stack[ejvr.frame].vType = 0
ejvr.stack[ejvr.frame].depth = 0
}
func (ejvr *extJSONValueReader) pushDocument() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mDocument
ejvr.stack[ejvr.frame].depth = ejvr.p.depth
}
func (ejvr *extJSONValueReader) pushCodeWithScope() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mCodeWithScope
}
func (ejvr *extJSONValueReader) pushArray() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mArray
}
func (ejvr *extJSONValueReader) push(m mode, t Type) {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = m
ejvr.stack[ejvr.frame].vType = t
}
func (ejvr *extJSONValueReader) pop() {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
ejvr.frame--
case mDocument, mArray, mCodeWithScope:
ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (ejvr *extJSONValueReader) skipObject() {
// read entire object until depth returns to 0 (last ending } or ] seen)
depth := 1
for depth > 0 {
ejvr.p.advanceState()
// If object is empty, raise depth and continue. When emptyObject is true, the
// parser has already read both the opening and closing brackets of an empty
// object ("{}"), so the next valid token will be part of the parent document,
// not part of the nested document.
//
// If there is a comma, there are remaining fields, emptyObject must be set back
// to false, and comma must be skipped with advanceState().
if ejvr.p.emptyObject {
if ejvr.p.s == jpsSawComma {
ejvr.p.emptyObject = false
ejvr.p.advanceState()
}
depth--
continue
}
switch ejvr.p.s {
case jpsSawBeginObject, jpsSawBeginArray:
depth++
case jpsSawEndObject, jpsSawEndArray:
depth--
}
}
}
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvr.stack[ejvr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if ejvr.frame != 0 {
te.parent = ejvr.stack[ejvr.frame-1].mode
}
return te
}
func (ejvr *extJSONValueReader) typeError(t Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
}
func (ejvr *extJSONValueReader) ensureElementValue(t Type, destination mode, callerName string, addModes ...mode) error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != t {
return ejvr.typeError(t)
}
default:
modes := []mode{mElement, mValue}
if addModes != nil {
modes = append(modes, addModes...)
}
return ejvr.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvr *extJSONValueReader) Type() Type {
return ejvr.stack[ejvr.frame].vType
}
func (ejvr *extJSONValueReader) Skip() error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
default:
return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
defer ejvr.pop()
t := ejvr.stack[ejvr.frame].vType
switch t {
case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
// read entire array, doc or CodeWithScope
ejvr.skipObject()
default:
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
}
return nil
}
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel: // allow reading array from top level
case mArray:
return ejvr, nil
default:
if err := ejvr.ensureElementValue(TypeArray, mArray, "ReadArray", mTopLevel, mArray); err != nil {
return nil, err
}
}
ejvr.pushArray()
return ejvr, nil
}
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := ejvr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
v, err := ejvr.p.readValue(TypeBinary)
if err != nil {
return nil, 0, err
}
b, btype, err = v.parseBinary()
ejvr.pop()
return b, btype, err
}
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
if err := ejvr.ensureElementValue(TypeBoolean, 0, "ReadBoolean"); err != nil {
return false, err
}
v, err := ejvr.p.readValue(TypeBoolean)
if err != nil {
return false, err
}
if v.t != TypeBoolean {
return false, fmt.Errorf("expected type bool, but got type %s", v.t)
}
ejvr.pop()
return v.v.(bool), nil
}
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel:
return ejvr, nil
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != TypeEmbeddedDocument {
return nil, ejvr.typeError(TypeEmbeddedDocument)
}
ejvr.pushDocument()
return ejvr, nil
default:
return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
}
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err = ejvr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
v, err := ejvr.p.readValue(TypeCodeWithScope)
if err != nil {
return "", nil, err
}
code, err = v.parseJavascript()
ejvr.pushCodeWithScope()
return code, ejvr, err
}
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
if err = ejvr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
return "", NilObjectID, err
}
v, err := ejvr.p.readValue(TypeDBPointer)
if err != nil {
return "", NilObjectID, err
}
ns, oid, err = v.parseDBPointer()
ejvr.pop()
return ns, oid, err
}
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
if err := ejvr.ensureElementValue(TypeDateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeDateTime)
if err != nil {
return 0, err
}
d, err := v.parseDateTime()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDecimal128() (Decimal128, error) {
if err := ejvr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128"); err != nil {
return Decimal128{}, err
}
v, err := ejvr.p.readValue(TypeDecimal128)
if err != nil {
return Decimal128{}, err
}
d, err := v.parseDecimal128()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
if err := ejvr.ensureElementValue(TypeDouble, 0, "ReadDouble"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeDouble)
if err != nil {
return 0, err
}
d, err := v.parseDouble()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
if err := ejvr.ensureElementValue(TypeInt32, 0, "ReadInt32"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeInt32)
if err != nil {
return 0, err
}
i, err := v.parseInt32()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
if err := ejvr.ensureElementValue(TypeInt64, 0, "ReadInt64"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeInt64)
if err != nil {
return 0, err
}
i, err := v.parseInt64()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
if err = ejvr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(TypeJavaScript)
if err != nil {
return "", err
}
code, err = v.parseJavascript()
ejvr.pop()
return code, err
}
func (ejvr *extJSONValueReader) ReadMaxKey() error {
if err := ejvr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeMaxKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("max")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadMinKey() error {
if err := ejvr.ensureElementValue(TypeMinKey, 0, "ReadMinKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeMinKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("min")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadNull() error {
if err := ejvr.ensureElementValue(TypeNull, 0, "ReadNull"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeNull)
if err != nil {
return err
}
if v.t != TypeNull {
return fmt.Errorf("expected type null but got type %s", v.t)
}
ejvr.pop()
return nil
}
func (ejvr *extJSONValueReader) ReadObjectID() (ObjectID, error) {
if err := ejvr.ensureElementValue(TypeObjectID, 0, "ReadObjectID"); err != nil {
return ObjectID{}, err
}
v, err := ejvr.p.readValue(TypeObjectID)
if err != nil {
return ObjectID{}, err
}
oid, err := v.parseObjectID()
ejvr.pop()
return oid, err
}
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
if err = ejvr.ensureElementValue(TypeRegex, 0, "ReadRegex"); err != nil {
return "", "", err
}
v, err := ejvr.p.readValue(TypeRegex)
if err != nil {
return "", "", err
}
pattern, options, err = v.parseRegex()
ejvr.pop()
return pattern, options, err
}
func (ejvr *extJSONValueReader) ReadString() (string, error) {
if err := ejvr.ensureElementValue(TypeString, 0, "ReadString"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(TypeString)
if err != nil {
return "", err
}
if v.t != TypeString {
return "", fmt.Errorf("expected type string but got type %s", v.t)
}
ejvr.pop()
return v.v.(string), nil
}
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
if err = ejvr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(TypeSymbol)
if err != nil {
return "", err
}
symbol, err = v.parseSymbol()
ejvr.pop()
return symbol, err
}
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err = ejvr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
v, err := ejvr.p.readValue(TypeTimestamp)
if err != nil {
return 0, 0, err
}
t, i, err = v.parseTimestamp()
ejvr.pop()
return t, i, err
}
func (ejvr *extJSONValueReader) ReadUndefined() error {
if err := ejvr.ensureElementValue(TypeUndefined, 0, "ReadUndefined"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeUndefined)
if err != nil {
return err
}
err = v.parseUndefined()
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
name, t, err := ejvr.p.readKey()
if err != nil {
if errors.Is(err, ErrEOD) {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
return "", nil, err
}
}
ejvr.pop()
}
return "", nil, err
}
ejvr.push(mElement, t)
return name, ejvr, nil
}
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mArray:
default:
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := ejvr.p.peekType()
if err != nil {
if errors.Is(err, ErrEOA) {
ejvr.pop()
}
return nil, err
}
ejvr.push(mValue, t)
return ejvr, nil
}
+223
View File
@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
+485
View File
@@ -0,0 +1,485 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"time"
)
func wrapperKeyBSONType(key string) Type {
switch key {
case "$numberInt":
return TypeInt32
case "$numberLong":
return TypeInt64
case "$oid":
return TypeObjectID
case "$symbol":
return TypeSymbol
case "$numberDouble":
return TypeDouble
case "$numberDecimal":
return TypeDecimal128
case "$binary":
return TypeBinary
case "$code":
return TypeJavaScript
case "$scope":
return TypeCodeWithScope
case "$timestamp":
return TypeTimestamp
case "$regularExpression":
return TypeRegex
case "$dbPointer":
return TypeDBPointer
case "$date":
return TypeDateTime
case "$minKey":
return TypeMinKey
case "$maxKey":
return TypeMaxKey
case "$undefined":
return TypeUndefined
}
return TypeEmbeddedDocument
}
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
if ejv.t != TypeEmbeddedDocument {
return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
}
binObj := ejv.v.(*extJSONObject)
bFound := false
stFound := false
for i, key := range binObj.keys {
val := binObj.values[i]
switch key {
case "base64":
if bFound {
return nil, 0, errors.New("duplicate base64 key in $binary")
}
if val.t != TypeString {
return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
}
base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
}
b = base64Bytes
bFound = true
case "subType":
if stFound {
return nil, 0, errors.New("duplicate subType key in $binary")
}
if val.t != TypeString {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}
i, err := strconv.ParseUint(val.v.(string), 16, 8)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
}
subType = byte(i)
stFound = true
default:
return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
}
}
if !bFound {
return nil, 0, errors.New("missing base64 field in $binary object")
}
if !stFound {
return nil, 0, errors.New("missing subType field in $binary object")
}
return b, subType, nil
}
func (ejv *extJSONValue) parseDBPointer() (ns string, oid ObjectID, err error) {
if ejv.t != TypeEmbeddedDocument {
return "", NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
}
dbpObj := ejv.v.(*extJSONObject)
oidFound := false
nsFound := false
for i, key := range dbpObj.keys {
val := dbpObj.values[i]
switch key {
case "$ref":
if nsFound {
return "", NilObjectID, errors.New("duplicate $ref key in $dbPointer")
}
if val.t != TypeString {
return "", NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
}
ns = val.v.(string)
nsFound = true
case "$id":
if oidFound {
return "", NilObjectID, errors.New("duplicate $id key in $dbPointer")
}
if val.t != TypeString {
return "", NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
}
oid, err = ObjectIDFromHex(val.v.(string))
if err != nil {
return "", NilObjectID, err
}
oidFound = true
default:
return "", NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
}
}
if !nsFound {
return "", oid, errors.New("missing $ref field in $dbPointer object")
}
if !oidFound {
return "", oid, errors.New("missing $id field in $dbPointer object")
}
return ns, oid, nil
}
const (
rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
)
var timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
func (ejv *extJSONValue) parseDateTime() (int64, error) {
switch ejv.t {
case TypeInt32:
return int64(ejv.v.(int32)), nil
case TypeInt64:
return ejv.v.(int64), nil
case TypeString:
return parseDatetimeString(ejv.v.(string))
case TypeEmbeddedDocument:
return parseDatetimeObject(ejv.v.(*extJSONObject))
default:
return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
}
}
func parseDatetimeString(data string) (int64, error) {
var t time.Time
var err error
// try acceptable time formats until one matches
for _, format := range timeFormats {
t, err = time.Parse(format, data)
if err == nil {
break
}
}
if err != nil {
return 0, fmt.Errorf("invalid $date value string: %s", data)
}
return int64(NewDateTimeFromTime(t)), nil
}
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
dFound := false
for i, key := range data.keys {
val := data.values[i]
switch key {
case "$numberLong":
if dFound {
return 0, errors.New("duplicate $numberLong key in $date")
}
if val.t != TypeString {
return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
}
d, err = val.parseInt64()
if err != nil {
return 0, err
}
dFound = true
default:
return 0, fmt.Errorf("invalid key in $date object: %s", key)
}
}
if !dFound {
return 0, errors.New("missing $numberLong field in $date object")
}
return d, nil
}
func (ejv *extJSONValue) parseDecimal128() (Decimal128, error) {
if ejv.t != TypeString {
return Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
}
d, err := ParseDecimal128(ejv.v.(string))
if err != nil {
return Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string))
}
return d, nil
}
func (ejv *extJSONValue) parseDouble() (float64, error) {
if ejv.t == TypeDouble {
return ejv.v.(float64), nil
}
if ejv.t != TypeString {
return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
}
switch ejv.v.(string) {
case "Infinity":
return math.Inf(1), nil
case "-Infinity":
return math.Inf(-1), nil
case "NaN":
return math.NaN(), nil
}
f, err := strconv.ParseFloat(ejv.v.(string), 64)
if err != nil {
return 0, err
}
return f, nil
}
func (ejv *extJSONValue) parseInt32() (int32, error) {
if ejv.t == TypeInt32 {
return ejv.v.(int32), nil
}
if ejv.t != TypeString {
return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
if i < math.MinInt32 || i > math.MaxInt32 {
return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
}
return int32(i), nil
}
func (ejv *extJSONValue) parseInt64() (int64, error) {
if ejv.t == TypeInt64 {
return ejv.v.(int64), nil
}
if ejv.t != TypeString {
return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
return i, nil
}
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
if ejv.t != TypeString {
return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
if ejv.t != TypeInt32 {
return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
}
if ejv.v.(int32) != 1 {
return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
}
return nil
}
func (ejv *extJSONValue) parseObjectID() (ObjectID, error) {
if ejv.t != TypeString {
return NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
}
return ObjectIDFromHex(ejv.v.(string))
}
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
if ejv.t != TypeEmbeddedDocument {
return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
}
regexObj := ejv.v.(*extJSONObject)
patFound := false
optFound := false
for i, key := range regexObj.keys {
val := regexObj.values[i]
switch key {
case "pattern":
if patFound {
return "", "", errors.New("duplicate pattern key in $regularExpression")
}
if val.t != TypeString {
return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
}
pattern = val.v.(string)
patFound = true
case "options":
if optFound {
return "", "", errors.New("duplicate options key in $regularExpression")
}
if val.t != TypeString {
return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
}
options = val.v.(string)
optFound = true
default:
return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
}
}
if !patFound {
return "", "", errors.New("missing pattern field in $regularExpression object")
}
if !optFound {
return "", "", errors.New("missing options field in $regularExpression object")
}
return pattern, options, nil
}
func (ejv *extJSONValue) parseSymbol() (string, error) {
if ejv.t != TypeString {
return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
if ejv.t != TypeEmbeddedDocument {
return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
}
handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
if flag {
return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
}
switch val.t {
case TypeInt32:
value := val.v.(int32)
if value < 0 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
}
return uint32(value), nil
case TypeInt64:
value := val.v.(int64)
if value < 0 || value > int64(math.MaxUint32) {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
}
return uint32(value), nil
default:
return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
}
}
tsObj := ejv.v.(*extJSONObject)
tFound := false
iFound := false
for j, key := range tsObj.keys {
val := tsObj.values[j]
switch key {
case "t":
if t, err = handleKey(key, val, tFound); err != nil {
return 0, 0, err
}
tFound = true
case "i":
if i, err = handleKey(key, val, iFound); err != nil {
return 0, 0, err
}
iFound = true
default:
return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
}
}
if !tFound {
return 0, 0, errors.New("missing t field in $timestamp object")
}
if !iFound {
return 0, 0, errors.New("missing i field in $timestamp object")
}
return t, i, nil
}
func (ejv *extJSONValue) parseUndefined() error {
if ejv.t != TypeBoolean {
return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
}
if !ejv.v.(bool) {
return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool))
}
return nil
}
+690
View File
@@ -0,0 +1,690 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"math"
"sort"
"strconv"
"strings"
"time"
"unicode/utf8"
)
type ejvwState struct {
mode mode
}
type extJSONValueWriter struct {
w io.Writer
buf []byte
stack []ejvwState
frame int64
canonical bool
escapeHTML bool
newlines bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) ValueWriter {
// Enable newlines for all Extended JSON value writers created by NewExtJSONValueWriter. We
// expect these value writers to be used with an Encoder, which should add newlines after
// encoded Extended JSON documents.
return newExtJSONWriter(w, canonical, escapeHTML, true)
}
func newExtJSONWriter(w io.Writer, canonical, escapeHTML, newlines bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
w: w,
buf: []byte{},
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
newlines: newlines,
}
}
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
buf: buf,
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
if ejvw.stack == nil {
ejvw.stack = make([]ejvwState, 1, 5)
}
ejvw.stack = ejvw.stack[:1]
ejvw.stack[0] = ejvwState{mode: mTopLevel}
ejvw.canonical = canonical
ejvw.escapeHTML = escapeHTML
ejvw.frame = 0
ejvw.buf = buf
ejvw.w = nil
}
func (ejvw *extJSONValueWriter) advanceFrame() {
if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
length := len(ejvw.stack)
if length+1 >= cap(ejvw.stack) {
// double it
buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
copy(buf, ejvw.stack)
ejvw.stack = buf
}
ejvw.stack = ejvw.stack[:length+1]
}
ejvw.frame++
}
func (ejvw *extJSONValueWriter) push(m mode) {
ejvw.advanceFrame()
ejvw.stack[ejvw.frame].mode = m
}
func (ejvw *extJSONValueWriter) pop() {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
ejvw.frame--
case mDocument, mArray, mCodeWithScope:
ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvw.stack[ejvw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if ejvw.frame != 0 {
te.parent = ejvw.stack[ejvw.frame-1].mode
}
return te
}
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return ejvw.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
var s string
if quotes {
s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
} else {
s = fmt.Sprintf(`{"$%s":%s}`, key, value)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '[')
ejvw.push(mArray)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
return ejvw.WriteBinaryWithSubtype(b, 0x00)
}
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$binary":{"base64":"`)
buf.WriteString(base64.StdEncoding.EncodeToString(b))
buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
var buf bytes.Buffer
buf.WriteString(`{"$code":`)
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
buf.WriteString(`,"$scope":{`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.push(mCodeWithScope)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$dbPointer":{"$ref":"`)
buf.WriteString(ns)
buf.WriteString(`","$id":{"$oid":"`)
buf.WriteString(oid.Hex())
buf.WriteString(`"}}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
return err
}
t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
ejvw.writeExtendedSingleValue("date", s, false)
} else {
ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDecimal128(d Decimal128) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
if ejvw.stack[ejvw.frame].mode == mTopLevel {
ejvw.buf = append(ejvw.buf, '{')
return ejvw, nil
}
if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '{')
ejvw.push(mDocument)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
return err
}
s := formatDouble(f)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberDouble", s, true)
} else {
switch s {
case "Infinity":
fallthrough
case "-Infinity":
fallthrough
case "NaN":
s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
return err
}
s := strconv.FormatInt(int64(i), 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberInt", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
return err
}
s := strconv.FormatInt(i, 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberLong", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("code", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("maxKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMinKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("minKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteNull() error {
if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte("null")...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteObjectID(oid ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
return err
}
options = sortStringAlphebeticAscending(options)
var buf bytes.Buffer
buf.WriteString(`{"$regularExpression":{"pattern":`)
writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
buf.WriteString(`,"options":`)
writeStringWithEscapes(options, &buf, ejvw.escapeHTML)
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteString(s string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$timestamp":{"t":`)
buf.WriteString(strconv.FormatUint(uint64(t), 10))
buf.WriteString(`,"i":`)
buf.WriteString(strconv.FormatUint(uint64(i), 10))
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteUndefined() error {
if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("undefined", "true", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
var buf bytes.Buffer
writeStringWithEscapes(key, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...)
ejvw.push(mElement)
default:
return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
default:
return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
}
// close the document
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = '}'
} else {
ejvw.buf = append(ejvw.buf, '}')
}
switch ejvw.stack[ejvw.frame].mode {
case mCodeWithScope:
ejvw.buf = append(ejvw.buf, '}')
fallthrough
case mDocument:
ejvw.buf = append(ejvw.buf, ',')
case mTopLevel:
// If the value writer has newlines enabled, end top-level documents with a newline so that
// multiple documents encoded to the same writer are separated by newlines. That matches the
// Go json.Encoder behavior and also works with NewExtJSONValueReader.
if ejvw.newlines {
ejvw.buf = append(ejvw.buf, '\n')
}
if ejvw.w != nil {
if _, err := ejvw.w.Write(ejvw.buf); err != nil {
return err
}
ejvw.buf = ejvw.buf[:0]
}
}
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
ejvw.push(mValue)
default:
return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
// close the array
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = ']'
} else {
ejvw.buf = append(ejvw.buf, ']')
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
default:
return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
}
return nil
}
func formatDouble(f float64) string {
var s string
switch {
case math.IsInf(f, 1):
s = "Infinity"
case math.IsInf(f, -1):
s = "-Infinity"
case math.IsNaN(f):
s = "NaN"
default:
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
var hexChars = "0123456789abcdef"
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
}
type sortableString []rune
func (ss sortableString) Len() int {
return len(ss)
}
func (ss sortableString) Less(i, j int) bool {
return ss[i] < ss[j]
}
func (ss sortableString) Swap(i, j int) {
ss[i], ss[j] = ss[j], ss[i]
}
func sortStringAlphebeticAscending(s string) string {
ss := sortableString([]rune(s))
sort.Sort(ss)
return string([]rune(ss))
}
+532
View File
@@ -0,0 +1,532 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"strconv"
"unicode"
"unicode/utf16"
)
type jsonTokenType byte
const (
jttBeginObject jsonTokenType = iota
jttEndObject
jttBeginArray
jttEndArray
jttColon
jttComma
jttInt32
jttInt64
jttDouble
jttString
jttBool
jttNull
jttEOF
)
type jsonToken struct {
t jsonTokenType
v any
p int
}
type jsonScanner struct {
r io.Reader
buf []byte
pos int
lastReadErr error
}
// nextToken returns the next JSON token if one exists. A token is a character
// of the JSON grammar, a number, a string, or a literal.
func (js *jsonScanner) nextToken() (*jsonToken, error) {
c, err := js.readNextByte()
// keep reading until a non-space is encountered (break on read error or EOF)
for isWhiteSpace(c) && err == nil {
c, err = js.readNextByte()
}
if errors.Is(err, io.EOF) {
return &jsonToken{t: jttEOF}, nil
} else if err != nil {
return nil, err
}
// switch on the character
switch c {
case '{':
return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil
case '}':
return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil
case '[':
return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil
case ']':
return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil
case ':':
return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil
case ',':
return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil
case '"': // RFC-8259 only allows for double quotes (") not single (')
return js.scanString()
default:
// check if it's a number
switch {
case c == '-' || isDigit(c):
return js.scanNumber(c)
case c == 't' || c == 'f' || c == 'n':
// maybe a literal
return js.scanLiteral(c)
default:
return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
}
}
}
// readNextByte attempts to read the next byte from the buffer. If the buffer
// has been exhausted, this function calls readIntoBuf, thus refilling the
// buffer and resetting the read position to 0
func (js *jsonScanner) readNextByte() (byte, error) {
if js.pos >= len(js.buf) {
err := js.readIntoBuf()
if err != nil {
return 0, err
}
}
b := js.buf[js.pos]
js.pos++
return b, nil
}
// readNNextBytes reads n bytes into dst, starting at offset
func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
var err error
for i := 0; i < n; i++ {
dst[i+offset], err = js.readNextByte()
if err != nil {
return err
}
}
return nil
}
// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer
func (js *jsonScanner) readIntoBuf() error {
if js.lastReadErr != nil {
js.buf = js.buf[:0]
js.pos = 0
return js.lastReadErr
}
if cap(js.buf) == 0 {
js.buf = make([]byte, 0, 512)
}
n, err := js.r.Read(js.buf[:cap(js.buf)])
if err != nil {
js.lastReadErr = err
if n > 0 {
err = nil
}
}
js.buf = js.buf[:n]
js.pos = 0
return err
}
func isWhiteSpace(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
func isDigit(c byte) bool {
return unicode.IsDigit(rune(c))
}
func isValueTerminator(c byte) bool {
return c == ',' || c == '}' || c == ']' || isWhiteSpace(c)
}
// getu4 decodes the 4-byte hex sequence from the beginning of s, returning the hex value as a rune,
// or it returns -1. Note that the "\u" from the unicode escape sequence should not be present.
// It is copied and lightly modified from the Go JSON decode function at
// https://github.com/golang/go/blob/1b0a0316802b8048d69da49dc23c5a5ab08e8ae8/src/encoding/json/decode.go#L1169-L1188
func getu4(s []byte) rune {
if len(s) < 4 {
return -1
}
var r rune
for _, c := range s[:4] {
switch {
case '0' <= c && c <= '9':
c -= '0'
case 'a' <= c && c <= 'f':
c = c - 'a' + 10
case 'A' <= c && c <= 'F':
c = c - 'A' + 10
default:
return -1
}
r = r*16 + rune(c)
}
return r
}
// scanString reads from an opening '"' to a closing '"' and handles escaped characters
func (js *jsonScanner) scanString() (*jsonToken, error) {
var b bytes.Buffer
var c byte
var err error
p := js.pos - 1
for {
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
evalNextChar:
switch c {
case '\\':
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
evalNextEscapeChar:
switch c {
case '"', '\\', '/':
b.WriteByte(c)
case 'b':
b.WriteByte('\b')
case 'f':
b.WriteByte('\f')
case 'n':
b.WriteByte('\n')
case 'r':
b.WriteByte('\r')
case 't':
b.WriteByte('\t')
case 'u':
us := make([]byte, 4)
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
rn := getu4(us)
// If the rune we just decoded is the high or low value of a possible surrogate pair,
// try to decode the next sequence as the low value of a surrogate pair. We're
// expecting the next sequence to be another Unicode escape sequence (e.g. "\uDD1E"),
// but need to handle cases where the input is not a valid surrogate pair.
// For more context on unicode surrogate pairs, see:
// https://www.christianfscott.com/rust-chars-vs-go-runes/
// https://www.unicode.org/glossary/#high_surrogate_code_point
if utf16.IsSurrogate(rn) {
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// If the next value isn't the beginning of a backslash escape sequence, write
// the Unicode replacement character for the surrogate value and goto the
// beginning of the next char eval block.
if c != '\\' {
b.WriteRune(unicode.ReplacementChar)
goto evalNextChar
}
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// If the next value isn't the beginning of a unicode escape sequence, write the
// Unicode replacement character for the surrogate value and goto the beginning
// of the next escape char eval block.
if c != 'u' {
b.WriteRune(unicode.ReplacementChar)
goto evalNextEscapeChar
}
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
rn2 := getu4(us)
// Try to decode the pair of runes as a utf16 surrogate pair. If that fails, write
// the Unicode replacement character for the surrogate value and the 2nd decoded rune.
if rnPair := utf16.DecodeRune(rn, rn2); rnPair != unicode.ReplacementChar {
b.WriteRune(rnPair)
} else {
b.WriteRune(unicode.ReplacementChar)
b.WriteRune(rn2)
}
break
}
b.WriteRune(rn)
default:
return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
}
case '"':
return &jsonToken{t: jttString, v: b.String(), p: p}, nil
default:
b.WriteByte(c)
}
}
}
// scanLiteral reads an unquoted sequence of characters and determines if it is one of
// three valid JSON literals (true, false, null); if so, it returns the appropriate
// jsonToken; otherwise, it returns an error
func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
p := js.pos - 1
lit := make([]byte, 4)
lit[0] = first
err := js.readNNextBytes(lit, 3, 1)
if err != nil {
return nil, err
}
c5, err := js.readNextByte()
switch {
case bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)):
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: true, p: p}, nil
case bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)):
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttNull, v: nil, p: p}, nil
case bytes.Equal([]byte("fals"), lit):
if c5 == 'e' {
c5, err = js.readNextByte()
if isValueTerminator(c5) || errors.Is(err, io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: false, p: p}, nil
}
}
}
return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
}
type numberScanState byte
const (
nssSawLeadingMinus numberScanState = iota
nssSawLeadingZero
nssSawIntegerDigits
nssSawDecimalPoint
nssSawFractionDigits
nssSawExponentLetter
nssSawExponentSign
nssSawExponentDigits
nssDone
nssInvalid
)
// scanNumber reads a JSON number (according to RFC-8259)
func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
var b bytes.Buffer
var s numberScanState
var c byte
var err error
t := jttInt64 // assume it's an int64 until the type can be determined
start := js.pos - 1
b.WriteByte(first)
switch first {
case '-':
s = nssSawLeadingMinus
case '0':
s = nssSawLeadingZero
default:
s = nssSawIntegerDigits
}
for {
c, err = js.readNextByte()
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
switch s {
case nssSawLeadingMinus:
switch c {
case '0':
s = nssSawLeadingZero
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawLeadingZero:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || errors.Is(err, io.EOF) {
s = nssDone
} else {
s = nssInvalid
}
}
case nssSawIntegerDigits:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
switch {
case isWhiteSpace(c) || errors.Is(err, io.EOF):
s = nssDone
case isDigit(c):
s = nssSawIntegerDigits
b.WriteByte(c)
default:
s = nssInvalid
}
}
case nssSawDecimalPoint:
t = jttDouble
if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawFractionDigits:
switch c {
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
switch {
case isWhiteSpace(c) || errors.Is(err, io.EOF):
s = nssDone
case isDigit(c):
s = nssSawFractionDigits
b.WriteByte(c)
default:
s = nssInvalid
}
}
case nssSawExponentLetter:
t = jttDouble
switch c {
case '+', '-':
s = nssSawExponentSign
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentSign:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawExponentDigits:
switch c {
case '}', ']', ',':
s = nssDone
default:
switch {
case isWhiteSpace(c) || errors.Is(err, io.EOF):
s = nssDone
case isDigit(c):
s = nssSawExponentDigits
b.WriteByte(c)
default:
s = nssInvalid
}
}
}
switch s {
case nssInvalid:
return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
case nssDone:
js.pos = int(math.Max(0, float64(js.pos-1)))
if t != jttDouble {
v, err := strconv.ParseInt(b.String(), 10, 64)
if err == nil {
if v < math.MinInt32 || v > math.MaxInt32 {
return &jsonToken{t: jttInt64, v: v, p: start}, nil
}
return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
}
}
v, err := strconv.ParseFloat(b.String(), 64)
if err != nil {
return nil, err
}
return &jsonToken{t: jttDouble, v: v, p: start}, nil
}
}
}
+293
View File
@@ -0,0 +1,293 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding"
"errors"
"fmt"
"reflect"
"strconv"
)
// mapCodec is the Codec used for map values.
type mapCodec struct {
// DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination
// value passed to Decode before unmarshaling BSON documents into them.
decodeZerosMap bool
// EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
// BSON null.
encodeNilAsEmpty bool
// EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name
// strings using fmt.Sprintf() instead of the default string conversion logic.
encodeKeysWithStringer bool
}
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
type KeyMarshaler interface {
MarshalKey() (key string, err error)
}
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
type KeyUnmarshaler interface {
UnmarshalKey(key string) error
}
// EncodeValue is the ValueEncoder for map[*]* types.
func (mc *mapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Map {
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
if val.IsNil() && !mc.encodeNilAsEmpty && !ec.nilMapAsEmpty {
// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
// so just continue. The operations on a map reflection value are valid, so we can call
// MapKeys within mapEncodeValue without a problem.
err := vw.WriteNull()
if err == nil {
return nil
}
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
err = mc.encodeMapElements(ec, dw, val, nil)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// encodeMapElements handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
func (mc *mapCodec) encodeMapElements(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
keys := val.MapKeys()
for _, key := range keys {
keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
if err != nil {
return err
}
if collisionFn != nil && collisionFn(keyStr) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}
vw, err := dw.WriteDocumentElement(keyStr)
if err != nil {
return err
}
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return nil
}
// DecodeValue is the ValueDecoder for map[string/decimal]* types.
func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
switch vrType := vr.Type(); vrType {
case Type(0), TypeEmbeddedDocument:
case TypeNull:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case TypeUndefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeMap(val.Type()))
}
if val.Len() > 0 && (mc.decodeZerosMap || dc.zeroMaps) {
clearMap(val)
}
eType := val.Type().Elem()
decoder, err := dc.LookupDecoder(eType)
if err != nil {
return err
}
keyType := val.Type().Key()
for {
key, vr, err := dr.ReadElement()
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
return err
}
k, err := mc.decodeKey(key, keyType)
if err != nil {
return err
}
elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType)
if err != nil {
return newDecodeError(key, err)
}
val.SetMapIndex(k, elem)
}
return nil
}
func clearMap(m reflect.Value) {
var none reflect.Value
for _, k := range m.MapKeys() {
m.SetMapIndex(k, none)
}
}
func (mc *mapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
if mc.encodeKeysWithStringer || encodeKeysWithStringer {
return fmt.Sprint(val), nil
}
// keys of any string type are used directly
if val.Kind() == reflect.String {
return val.String(), nil
}
// KeyMarshalers are marshaled
if km, ok := val.Interface().(KeyMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}
buf, err := km.MarshalKey()
if err == nil {
return buf, nil
}
return "", err
}
// keys implement encoding.TextMarshaler are marshaled.
if km, ok := val.Interface().(encoding.TextMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}
buf, err := km.MarshalText()
if err != nil {
return "", err
}
return string(buf), nil
}
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(val.Uint(), 10), nil
}
return "", fmt.Errorf("unsupported key type: %v", val.Type())
}
var (
keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)
func (mc *mapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
keyVal := reflect.ValueOf(key)
var err error
switch {
// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
case !mc.encodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(KeyUnmarshaler)
err = v.UnmarshalKey(key)
keyVal = keyVal.Elem()
// Try to decode encoding.TextUnmarshalers.
case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(encoding.TextUnmarshaler)
err = v.UnmarshalText([]byte(key))
keyVal = keyVal.Elem()
// Otherwise, go to type specific behavior
default:
switch keyType.Kind() {
case reflect.String:
keyVal = reflect.ValueOf(key).Convert(keyType)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, parseErr := strconv.ParseInt(key, 10, 64)
if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
err = fmt.Errorf("failed to unmarshal number key %v", key)
}
keyVal = reflect.ValueOf(n).Convert(keyType)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, parseErr := strconv.ParseUint(key, 10, 64)
if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
err = fmt.Errorf("failed to unmarshal number key %v", key)
break
}
keyVal = reflect.ValueOf(n).Convert(keyType)
case reflect.Float32, reflect.Float64:
if mc.encodeKeysWithStringer {
parsed, err := strconv.ParseFloat(key, 64)
if err != nil {
return keyVal, fmt.Errorf("map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err)
}
keyVal = reflect.ValueOf(parsed)
break
}
fallthrough
default:
return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
}
}
return keyVal, err
}
+191
View File
@@ -0,0 +1,191 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/json"
"sync"
)
const defaultDstCap = 256
var extjPool = sync.Pool{
New: func() any {
return new(extJSONValueWriter)
},
}
// Marshaler is the interface implemented by types that can marshal themselves
// into a valid BSON document.
//
// Implementations of Marshaler must return a full BSON document. To create
// custom BSON marshaling behavior for individual values in a BSON document,
// implement the ValueMarshaler interface instead.
type Marshaler interface {
MarshalBSON() ([]byte, error)
}
// ValueMarshaler is the interface implemented by types that can marshal
// themselves into a valid BSON value. The format of the returned bytes must
// match the returned type.
//
// Implementations of ValueMarshaler must return an individual BSON value. To
// create custom BSON marshaling behavior for an entire BSON document, implement
// the Marshaler interface instead.
type ValueMarshaler interface {
MarshalBSONValue() (typ byte, data []byte, err error)
}
// Pool of buffers for marshalling BSON.
var bufPool = sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
}
// Marshal returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed into a
// document, MarshalValue should be used instead.
//
// Marshal will use the default registry created by NewRegistry to recursively
// marshal val into a []byte. Marshal will inspect struct tags and alter the
// marshaling process accordingly.
func Marshal(val any) ([]byte, error) {
sw := bufPool.Get().(*bytes.Buffer)
defer func() {
// Proper usage of a sync.Pool requires each entry to have approximately
// the same memory cost. To obtain this property when the stored type
// contains a variably-sized buffer, we add a hard limit on the maximum
// buffer to place back in the pool. We limit the size to 16MiB because
// that's the maximum wire message size supported by any current MongoDB
// server.
//
// Comment based on
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
//
// Recycle byte slices that are smaller than 16MiB and at least half
// occupied.
if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
bufPool.Put(sw)
}
}()
sw.Reset()
vw := getDocumentWriter(sw)
defer putDocumentWriter(vw)
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vw)
enc.SetRegistry(defaultRegistry)
err := enc.Encode(val)
if err != nil {
return nil, err
}
buf := append([]byte(nil), sw.Bytes()...)
return buf, nil
}
// MarshalValue returns the BSON encoding of val.
//
// MarshalValue will use bson.NewRegistry() to transform val into a BSON value. If val is a struct, this function will
// inspect struct tags and alter the marshalling process accordingly.
func MarshalValue(val any) (Type, []byte, error) {
sw := bufPool.Get().(*bytes.Buffer)
defer func() {
// Proper usage of a sync.Pool requires each entry to have approximately
// the same memory cost. To obtain this property when the stored type
// contains a variably-sized buffer, we add a hard limit on the maximum
// buffer to place back in the pool. We limit the size to 16MiB because
// that's the maximum wire message size supported by any current MongoDB
// server.
//
// Comment based on
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
//
// Recycle byte slices that are smaller than 16MiB and at least half
// occupied.
if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
bufPool.Put(sw)
}
}()
sw.Reset()
vwFlusher := newDocumentWriter(sw)
vw, err := vwFlusher.WriteDocumentElement("")
if err != nil {
return 0, nil, err
}
// get an Encoder and encode the value
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vw)
enc.SetRegistry(defaultRegistry)
if err := enc.Encode(val); err != nil {
return 0, nil, err
}
// flush the bytes written because we cannot guarantee that a full document has been written
// after the flush, *sw will be in the format
// [value type, 0 (null byte to indicate end of empty element name), value bytes..]
if err := vwFlusher.Flush(); err != nil {
return 0, nil, err
}
typ := sw.Next(2)
clone := append([]byte{}, sw.Bytes()...) // Don't hand out a shared reference to byte buffer bytes
// and fully copy the data. The byte buffer is (potentially) reused
// and handing out only a reference to the bytes may lead to race-conditions with the buffer.
return Type(typ[0]), clone, nil
}
// MarshalExtJSON returns the extended JSON encoding of val.
func MarshalExtJSON(val any, canonical, escapeHTML bool) ([]byte, error) {
sw := sliceWriter(make([]byte, 0, defaultDstCap))
ejvw := extjPool.Get().(*extJSONValueWriter)
ejvw.reset(sw, canonical, escapeHTML)
ejvw.w = &sw
defer func() {
ejvw.buf = nil
ejvw.w = nil
extjPool.Put(ejvw)
}()
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(ejvw)
enc.ec = EncodeContext{Registry: defaultRegistry}
err := enc.Encode(val)
if err != nil {
return nil, err
}
return sw, nil
}
// IndentExtJSON will prefix and indent the provided extended JSON src and append it to dst.
func IndentExtJSON(dst *bytes.Buffer, src []byte, prefix, indent string) error {
return json.Indent(dst, src, prefix, indent)
}
// MarshalExtJSONIndent returns the extended JSON encoding of val with each line with prefixed
// and indented.
func MarshalExtJSONIndent(val any, canonical, escapeHTML bool, prefix, indent string) ([]byte, error) {
marshaled, err := MarshalExtJSON(val, canonical, escapeHTML)
if err != nil {
return nil, err
}
var buf bytes.Buffer
err = IndentExtJSON(&buf, marshaled, prefix, indent)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
+209
View File
@@ -0,0 +1,209 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"reflect"
)
var (
// ErrMgoSetZero may be returned from a SetBSON method to have the value set to its respective zero value.
ErrMgoSetZero = errors.New("set to zero")
tInt = reflect.TypeOf(int(0))
tM = reflect.TypeOf(M{})
tInterfaceSlice = reflect.TypeOf([]any{})
tGetter = reflect.TypeOf((*getter)(nil)).Elem()
tSetter = reflect.TypeOf((*setter)(nil)).Elem()
)
// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders.
func NewMgoRegistry() *Registry {
mapCodec := &mapCodec{
decodeZerosMap: true,
encodeNilAsEmpty: true,
encodeKeysWithStringer: true,
}
structCodec := &structCodec{
inlineMapEncoder: mapCodec,
decodeZeroStruct: true,
encodeOmitDefaultStruct: true,
allowUnexportedFields: true,
}
uintCodec := &uintCodec{encodeToMinSize: true}
reg := NewRegistry()
reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true})
reg.RegisterKindDecoder(reflect.String, ValueDecoderFunc(mgoStringDecodeValue))
reg.RegisterKindDecoder(reflect.Struct, structCodec)
reg.RegisterKindDecoder(reflect.Map, mapCodec)
reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true})
reg.RegisterKindEncoder(reflect.Struct, structCodec)
reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true})
reg.RegisterKindEncoder(reflect.Map, mapCodec)
reg.RegisterKindEncoder(reflect.Uint, uintCodec)
reg.RegisterKindEncoder(reflect.Uint8, uintCodec)
reg.RegisterKindEncoder(reflect.Uint16, uintCodec)
reg.RegisterKindEncoder(reflect.Uint32, uintCodec)
reg.RegisterKindEncoder(reflect.Uint64, uintCodec)
reg.RegisterTypeMapEntry(TypeInt32, tInt)
reg.RegisterTypeMapEntry(TypeDateTime, tTime)
reg.RegisterTypeMapEntry(TypeArray, tInterfaceSlice)
reg.RegisterTypeMapEntry(Type(0), tM)
reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tM)
reg.RegisterInterfaceEncoder(tGetter, ValueEncoderFunc(getterEncodeValue))
reg.RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(setterDecodeValue))
return reg
}
// NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to behave like mgo/bson
// with RespectNilValues set to true.
func NewRespectNilValuesMgoRegistry() *Registry {
mapCodec := &mapCodec{
decodeZerosMap: true,
}
reg := NewMgoRegistry()
reg.RegisterKindDecoder(reflect.Map, mapCodec)
reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false})
reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
reg.RegisterKindEncoder(reflect.Map, mapCodec)
return reg
}
func mgoStringDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueDecoderError{
Name: "StringDecodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: reflect.Zero(val.Type()),
}
}
if vr.Type() == TypeObjectID {
oid, err := vr.ReadObjectID()
if err != nil {
return err
}
if dc.objectIDAsHexString {
val.SetString(oid.Hex())
} else {
val.SetString(string(oid[:]))
}
return nil
}
return (&stringCodec{}).DecodeValue(dc, vr, val)
}
// setter interface: a value implementing the bson.Setter interface will receive the BSON
// value via the SetBSON method during unmarshaling, and the object
// itself will not be changed as usual.
//
// If setting the value works, the method should return nil or alternatively
// ErrMgoSetZero to set the respective field to its zero value (nil for
// pointer types). If SetBSON returns a non-nil error, the unmarshalling
// procedure will stop and error out with the provided value.
//
// This interface is generally useful in pointer receivers, since the method
// will want to change the receiver. A type field that implements the Setter
// interface doesn't have to be a pointer, though.
//
// For example:
//
// type MyString string
//
// func (s *MyString) SetBSON(raw bson.RawValue) error {
// return raw.Unmarshal(s)
// }
type setter interface {
SetBSON(raw RawValue) error
}
// getter interface: a value implementing the bson.Getter interface will have its GetBSON
// method called when the given value has to be marshalled, and the result
// of this method will be marshaled in place of the actual object.
//
// If GetBSON returns return a non-nil error, the marshalling procedure
// will stop and error out with the provided value.
type getter interface {
GetBSON() (any, error)
}
// setterDecodeValue is the ValueDecoderFunc for Setter types.
func setterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) {
return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
}
if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
}
val.Set(reflect.New(val.Type().Elem()))
}
if !val.Type().Implements(tSetter) {
if !val.CanAddr() {
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
}
val = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
}
t, src, err := copyValueToBytes(vr)
if err != nil {
return err
}
m, ok := val.Interface().(setter)
if !ok {
return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
}
if err := m.SetBSON(RawValue{Type: t, Value: src}); err != nil {
if !errors.Is(err, ErrMgoSetZero) {
return err
}
val.Set(reflect.Zero(val.Type()))
}
return nil
}
// getterEncodeValue is the ValueEncoderFunc for Getter types.
func getterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Getter
switch {
case !val.IsValid():
return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
case val.Type().Implements(tGetter):
// If Getter is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tGetter) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
}
m, ok := val.Interface().(getter)
if !ok {
return vw.WriteNull()
}
x, err := m.GetBSON()
if err != nil {
return err
}
if x == nil {
return vw.WriteNull()
}
vv := reflect.ValueOf(x)
encoder, err := ec.LookupEncoder(vv.Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, vv)
}
+82
View File
@@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
)
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "Document"
case mArray:
str = "Array"
case mValue:
str = "Value"
case mElement:
str = "Element"
case mCodeWithScope:
str = "CodeWithScope"
case mSpacer:
str = "CodeWithScopeSpacer"
default:
str = "Unknown"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
type TransitionError struct {
name string
parent mode
current mode
destination mode
modes []mode
action string
}
func (te TransitionError) Error() string {
errString := fmt.Sprintf("%s can only %s", te.name, te.action)
if te.destination != mode(0) {
errString = fmt.Sprintf("%s a %s", errString, te.destination)
}
errString = fmt.Sprintf("%s while positioned on a", errString)
for ind, m := range te.modes {
if ind != 0 && len(te.modes) > 2 {
errString = fmt.Sprintf("%s,", errString)
}
if ind == len(te.modes)-1 && len(te.modes) > 1 {
errString = fmt.Sprintf("%s or", errString)
}
errString = fmt.Sprintf("%s %s", errString, m)
}
errString = fmt.Sprintf("%s but is positioned on a %s", errString, te.current)
if te.parent != mode(0) {
errString = fmt.Sprintf("%s with parent %s", errString, te.parent)
}
return errString
}
+211
View File
@@ -0,0 +1,211 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"crypto/rand"
"encoding"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"sync/atomic"
"time"
)
// ErrInvalidHex indicates that a hex string cannot be converted to an ObjectID.
var ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID")
// ObjectID is the BSON ObjectID type.
type ObjectID [12]byte
// NilObjectID is the zero value for ObjectID.
var NilObjectID ObjectID
var (
objectIDCounter = readRandomUint32()
processUnique = processUniqueBytes()
)
var (
_ encoding.TextMarshaler = ObjectID{}
_ encoding.TextUnmarshaler = &ObjectID{}
)
// NewObjectID generates a new ObjectID.
func NewObjectID() ObjectID {
return NewObjectIDFromTimestamp(time.Now())
}
// NewObjectIDFromTimestamp generates a new ObjectID based on the given time.
func NewObjectIDFromTimestamp(timestamp time.Time) ObjectID {
var b [12]byte
binary.BigEndian.PutUint32(b[0:4], uint32(timestamp.Unix()))
copy(b[4:9], processUnique[:])
putUint24(b[9:12], atomic.AddUint32(&objectIDCounter, 1))
return b
}
// Timestamp extracts the time part of the ObjectId.
func (id ObjectID) Timestamp() time.Time {
unixSecs := binary.BigEndian.Uint32(id[0:4])
return time.Unix(int64(unixSecs), 0).UTC()
}
// Hex returns the hex encoding of the ObjectID as a string.
func (id ObjectID) Hex() string {
var buf [24]byte
hex.Encode(buf[:], id[:])
return string(buf[:])
}
func (id ObjectID) String() string {
return `ObjectID("` + id.Hex() + `")`
}
// IsZero returns true if id is the empty ObjectID.
func (id ObjectID) IsZero() bool {
return id == NilObjectID
}
// ObjectIDFromHex creates a new ObjectID from a hex string. It returns an error if the hex string is not a
// valid ObjectID.
func ObjectIDFromHex(s string) (ObjectID, error) {
if len(s) != 24 {
return NilObjectID, ErrInvalidHex
}
var oid [12]byte
_, err := hex.Decode(oid[:], []byte(s))
if err != nil {
return NilObjectID, err
}
return oid, nil
}
// MarshalText returns the ObjectID as UTF-8-encoded text. Implementing this allows us to use ObjectID
// as a map key when marshalling JSON. See https://pkg.go.dev/encoding#TextMarshaler
func (id ObjectID) MarshalText() ([]byte, error) {
var buf [24]byte
hex.Encode(buf[:], id[:])
return buf[:], nil
}
// UnmarshalText populates the byte slice with the ObjectID. If the byte slice
// is 24 bytes long, it will be populated with the hex representation of the
// ObjectID. If the byte slice is 12 bytes long, it will be populated with the
// BSON representation of the ObjectID. This method also accepts empty strings
// and decodes them as NilObjectID.
//
// For any other inputs, an error will be returned.
//
// Implementing this allows us to use ObjectID as a map key when unmarshalling
// JSON. See https://pkg.go.dev/encoding#TextUnmarshaler
func (id *ObjectID) UnmarshalText(b []byte) error {
// NB(charlie): The json package will use UnmarshalText instead of
// UnmarshalJSON if the value is a string.
// An empty string is not a valid ObjectID, but we treat it as a
// special value that decodes as NilObjectID.
if len(b) == 0 {
return nil
}
oid, err := ObjectIDFromHex(string(b))
if err != nil {
return err
}
*id = oid
return nil
}
// MarshalJSON returns the ObjectID as a string
func (id ObjectID) MarshalJSON() ([]byte, error) {
var buf [26]byte
buf[0] = '"'
hex.Encode(buf[1:25], id[:])
buf[25] = '"'
return buf[:], nil
}
// UnmarshalJSON populates the byte slice with the ObjectID. If the byte slice
// is 24 bytes long, it will be populated with the hex representation of the
// ObjectID. If the byte slice is 12 bytes long, it will be populated with the
// BSON representation of the ObjectID. This method also accepts empty strings
// and decodes them as NilObjectID.
//
// As a special case UnmarshalJSON will decode a JSON object with key "$oid"
// that stores a hex encoded ObjectID: {"$oid": "65b3f7edd9bfca00daa6e3b31"}.
//
// For any other inputs, an error will be returned.
func (id *ObjectID) UnmarshalJSON(b []byte) error {
// Ignore "null" to keep parity with the standard library. Decoding a JSON
// null into a non-pointer ObjectID field will leave the field unchanged.
// For pointer values, encoding/json will set the pointer to nil and will
// not enter the UnmarshalJSON hook.
if string(b) == "null" {
return nil
}
// Handle string
if len(b) >= 2 && b[0] == '"' {
// TODO: fails because of error
return id.UnmarshalText(b[1 : len(b)-1])
}
if len(b) == 12 {
copy(id[:], b)
return nil
}
var v struct {
OID *string `json:"$oid"`
}
if err := json.Unmarshal(b, &v); err != nil {
return fmt.Errorf("failed to parse extended JSON ObjectID: %w", err)
}
if v.OID == nil {
return errors.New("not an extended JSON ObjectID")
}
i, err := ObjectIDFromHex(*v.OID)
if err != nil {
return err
}
*id = i
return nil
}
func processUniqueBytes() [5]byte {
var b [5]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
}
return b
}
func readRandomUint32() uint32 {
var b [4]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
}
return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
}
func putUint24(b []byte, v uint32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
+88
View File
@@ -0,0 +1,88 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
var (
_ ValueEncoder = &pointerCodec{}
_ ValueDecoder = &pointerCodec{}
)
// pointerCodec is the Codec used for pointers.
type pointerCodec struct {
ecache typeEncoderCache
dcache typeDecoderCache
}
// EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil
// or looking up an encoder for the type of value the pointer points to.
func (pc *pointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.Ptr {
if !val.IsValid() {
return vw.WriteNull()
}
return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
typ := val.Type()
if v, ok := pc.ecache.Load(typ); ok {
if v == nil {
return errNoEncoder{Type: typ}
}
return v.EncodeValue(ec, vw, val.Elem())
}
// TODO(charlie): handle concurrent requests for the same type
enc, err := ec.LookupEncoder(typ.Elem())
enc = pc.ecache.LoadOrStore(typ, enc)
if err != nil {
return err
}
return enc.EncodeValue(ec, vw, val.Elem())
}
// DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and
// using that to decode. If the BSON value is Null, this method will set the pointer to nil.
func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Ptr {
return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
typ := val.Type()
if vr.Type() == TypeNull {
val.Set(reflect.Zero(typ))
return vr.ReadNull()
}
if vr.Type() == TypeUndefined {
val.Set(reflect.Zero(typ))
return vr.ReadUndefined()
}
if val.IsNil() {
val.Set(reflect.New(typ.Elem()))
}
if v, ok := pc.dcache.Load(typ); ok {
if v == nil {
return errNoDecoder{Type: typ}
}
return v.DecodeValue(dc, vr, val.Elem())
}
// TODO(charlie): handle concurrent requests for the same type
dec, err := dc.LookupDecoder(typ.Elem())
dec = pc.dcache.LoadOrStore(typ, dec)
if err != nil {
return err
}
return dec.DecodeValue(dc, vr, val.Elem())
}
+381
View File
@@ -0,0 +1,381 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// The following primitive types are similar to Go primitives for BSON types that
// do not have direct Go primitive representations.
// Binary represents a BSON binary value.
type Binary struct {
Subtype byte
Data []byte
}
// Equal compares bp to bp2 and returns true if they are equal.
func (bp Binary) Equal(bp2 Binary) bool {
if bp.Subtype != bp2.Subtype {
return false
}
return bytes.Equal(bp.Data, bp2.Data)
}
// IsZero returns if bp is the empty Binary.
func (bp Binary) IsZero() bool {
return bp.Subtype == 0 && len(bp.Data) == 0
}
// Undefined represents the BSON undefined value type.
type Undefined struct{}
// DateTime represents the BSON datetime value.
type DateTime int64
var (
_ json.Marshaler = DateTime(0)
_ json.Unmarshaler = (*DateTime)(nil)
)
// MarshalJSON marshal to time type.
func (d DateTime) MarshalJSON() ([]byte, error) {
return d.Time().UTC().MarshalJSON()
}
// UnmarshalJSON creates a bson.DateTime from a JSON string.
func (d *DateTime) UnmarshalJSON(data []byte) error {
// Ignore "null" so that we can distinguish between a "null" value and
// valid value that is the zero time (as reported by time.Time.IsZero).
if string(data) == "null" {
return nil
}
var t time.Time
if err := t.UnmarshalJSON(data); err != nil {
return err
}
*d = NewDateTimeFromTime(t)
return nil
}
// Time returns the date as a time type.
func (d DateTime) Time() time.Time {
return time.Unix(int64(d)/1000, int64(d)%1000*1000000)
}
// NewDateTimeFromTime creates a new DateTime from a Time.
func NewDateTimeFromTime(t time.Time) DateTime {
return DateTime(t.Unix()*1e3 + int64(t.Nanosecond())/1e6)
}
// Null represents the BSON null value.
type Null struct{}
// Regex represents a BSON regex value.
type Regex struct {
Pattern string
Options string
}
func (rp Regex) String() string {
return fmt.Sprintf(`{"pattern": "%s", "options": "%s"}`, rp.Pattern, rp.Options)
}
// Equal compares rp to rp2 and returns true if they are equal.
func (rp Regex) Equal(rp2 Regex) bool {
return rp.Pattern == rp2.Pattern && rp.Options == rp2.Options
}
// IsZero returns if rp is the empty Regex.
func (rp Regex) IsZero() bool {
return rp.Pattern == "" && rp.Options == ""
}
// DBPointer represents a BSON dbpointer value.
type DBPointer struct {
DB string
Pointer ObjectID
}
func (d DBPointer) String() string {
return fmt.Sprintf(`{"db": "%s", "pointer": "%s"}`, d.DB, d.Pointer)
}
// Equal compares d to d2 and returns true if they are equal.
func (d DBPointer) Equal(d2 DBPointer) bool {
return d == d2
}
// IsZero returns if d is the empty DBPointer.
func (d DBPointer) IsZero() bool {
return d.DB == "" && d.Pointer.IsZero()
}
// JavaScript represents a BSON JavaScript code value.
type JavaScript string
// Symbol represents a BSON symbol value.
type Symbol string
// CodeWithScope represents a BSON JavaScript code with scope value.
type CodeWithScope struct {
Code JavaScript
Scope any
}
func (cws CodeWithScope) String() string {
return fmt.Sprintf(`{"code": "%s", "scope": %v}`, cws.Code, cws.Scope)
}
// Timestamp represents a BSON timestamp value.
type Timestamp struct {
T uint32
I uint32
}
// After reports whether the time instant tp is after tp2.
func (tp Timestamp) After(tp2 Timestamp) bool {
return tp.T > tp2.T || (tp.T == tp2.T && tp.I > tp2.I)
}
// Before reports whether the time instant tp is before tp2.
func (tp Timestamp) Before(tp2 Timestamp) bool {
return tp.T < tp2.T || (tp.T == tp2.T && tp.I < tp2.I)
}
// Equal compares tp to tp2 and returns true if they are equal.
func (tp Timestamp) Equal(tp2 Timestamp) bool {
return tp.T == tp2.T && tp.I == tp2.I
}
// IsZero returns if tp is the zero Timestamp.
func (tp Timestamp) IsZero() bool {
return tp.T == 0 && tp.I == 0
}
// Compare compares the time instant tp with tp2. If tp is before tp2, it returns -1; if tp is after
// tp2, it returns +1; if they're the same, it returns 0.
func (tp Timestamp) Compare(tp2 Timestamp) int {
switch {
case tp.Equal(tp2):
return 0
case tp.Before(tp2):
return -1
default:
return +1
}
}
// MinKey represents the BSON minkey value.
type MinKey struct{}
// MaxKey represents the BSON maxkey value.
type MaxKey struct{}
// D is an ordered representation of a BSON document. This type should be used when the order of the elements matters,
// such as MongoDB command documents. If the order of the elements does not matter, an M should be used instead.
//
// Example usage:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
type D []E
func (d D) String() string {
b, err := MarshalExtJSON(d, true, false)
if err != nil {
return ""
}
return string(b)
}
// MarshalJSON encodes D into JSON.
func (d D) MarshalJSON() ([]byte, error) {
if d == nil {
return json.Marshal(nil)
}
var err error
var buf bytes.Buffer
buf.Write([]byte("{"))
enc := json.NewEncoder(&buf)
for i, e := range d {
err = enc.Encode(e.Key)
if err != nil {
return nil, err
}
buf.Write([]byte(":"))
err = enc.Encode(e.Value)
if err != nil {
return nil, err
}
if i < len(d)-1 {
buf.Write([]byte(","))
}
}
buf.Write([]byte("}"))
return json.RawMessage(buf.Bytes()).MarshalJSON()
}
// UnmarshalJSON decodes D from JSON.
func (d *D) UnmarshalJSON(b []byte) error {
dec := json.NewDecoder(bytes.NewReader(b))
t, err := dec.Token()
if err != nil {
return err
}
if t == nil {
*d = nil
return nil
}
if v, ok := t.(json.Delim); !ok || v != '{' {
return &json.UnmarshalTypeError{
Value: tokenString(t),
Type: reflect.TypeOf(D(nil)),
Offset: dec.InputOffset(),
}
}
*d, err = jsonDecodeD(dec)
return err
}
// E represents a BSON element for a D. It is usually used inside a D.
type E struct {
Key string
Value any
}
// M is an unordered representation of a BSON document. This type should be used when the order of the elements does not
// matter. This type is handled as a regular map[string]any when encoding and decoding. Elements will be
// serialized in an undefined, random order. If the order of the elements matters, a D should be used instead.
//
// Example usage:
//
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
type M map[string]any
func (m M) String() string {
b, err := MarshalExtJSON(m, true, false)
if err != nil {
return ""
}
return string(b)
}
// An A is an ordered representation of a BSON array.
//
// Example usage:
//
// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
type A []any
func jsonDecodeD(dec *json.Decoder) (D, error) {
res := D{}
for {
var e E
t, err := dec.Token()
if err != nil {
return nil, err
}
key, ok := t.(string)
if !ok {
break
}
e.Key = key
t, err = dec.Token()
if err != nil {
return nil, err
}
switch v := t.(type) {
case json.Delim:
switch v {
case '[':
e.Value, err = jsonDecodeSlice(dec)
if err != nil {
return nil, err
}
case '{':
e.Value, err = jsonDecodeD(dec)
if err != nil {
return nil, err
}
}
default:
e.Value = t
}
res = append(res, e)
}
return res, nil
}
func jsonDecodeSlice(dec *json.Decoder) ([]any, error) {
var res []any
done := false
for !done {
t, err := dec.Token()
if err != nil {
return nil, err
}
switch v := t.(type) {
case json.Delim:
switch v {
case '[':
a, err := jsonDecodeSlice(dec)
if err != nil {
return nil, err
}
res = append(res, a)
case '{':
d, err := jsonDecodeD(dec)
if err != nil {
return nil, err
}
res = append(res, d)
default:
done = true
}
default:
res = append(res, t)
}
}
return res, nil
}
func tokenString(t json.Token) string {
switch v := t.(type) {
case json.Delim:
switch v {
case '{':
return "object"
case '[':
return "array"
}
case bool:
return "bool"
case float64:
return "number"
case json.Number, string:
return "string"
}
return "unknown"
}
+91
View File
@@ -0,0 +1,91 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
)
var (
tRawValue = reflect.TypeOf(RawValue{})
tRaw = reflect.TypeOf(Raw(nil))
)
// registerPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created.
func registerPrimitiveCodecs(reg *Registry) {
reg.RegisterTypeEncoder(tRawValue, ValueEncoderFunc(rawValueEncodeValue))
reg.RegisterTypeEncoder(tRaw, ValueEncoderFunc(rawEncodeValue))
reg.RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue))
reg.RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue))
}
// rawValueEncodeValue is the ValueEncoderFunc for RawValue.
//
// If the RawValue's Type is "invalid" and the RawValue's Value is not empty or
// nil, then this method will return an error.
func rawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRawValue {
return ValueEncoderError{
Name: "RawValueEncodeValue",
Types: []reflect.Type{tRawValue},
Received: val,
}
}
rawvalue := val.Interface().(RawValue)
if !rawvalue.Type.IsValid() {
return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rawvalue.Type))
}
return copyValueFromBytes(vw, rawvalue.Type, rawvalue.Value)
}
// rawValueDecodeValue is the ValueDecoderFunc for RawValue.
func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tRawValue {
return ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val}
}
t, value, err := copyValueToBytes(vr)
if err != nil {
return err
}
val.Set(reflect.ValueOf(RawValue{Type: t, Value: value}))
return nil
}
// rawEncodeValue is the ValueEncoderFunc for Reader.
func rawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRaw {
return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val}
}
rdr := val.Interface().(Raw)
return copyDocumentFromBytes(vw, rdr)
}
// rawDecodeValue is the ValueDecoderFunc for Reader.
func rawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tRaw {
return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
rdr, err := appendDocumentBytes(val.Interface().(Raw), vr)
val.Set(reflect.ValueOf(rdr))
return err
}
+93
View File
@@ -0,0 +1,93 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"io"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// ErrNilReader indicates that an operation was attempted on a nil bson.Reader.
var ErrNilReader = errors.New("nil reader")
// Raw is a raw encoded BSON document. It can be used to delay BSON document decoding or precompute
// a BSON encoded document.
//
// A Raw must be a full BSON document. Use the RawValue type for individual BSON values.
type Raw []byte
// ReadDocument reads a BSON document from the io.Reader and returns it as a bson.Raw. If the
// reader contains multiple BSON documents, only the first document is read.
func ReadDocument(r io.Reader) (Raw, error) {
doc, err := bsoncore.NewDocumentFromReader(r)
return Raw(doc), err
}
// Validate validates the document. This method only validates the first document in
// the slice, to validate other documents, the slice must be resliced.
func (r Raw) Validate() (err error) { return bsoncore.Document(r).Validate() }
// Lookup search the document, potentially recursively, for the given key. If
// there are multiple keys provided, this method will recurse down, as long as
// the top and intermediate nodes are either documents or arrays.If an error
// occurs or if the value doesn't exist, an empty RawValue is returned.
func (r Raw) Lookup(key ...string) RawValue {
return convertFromCoreValue(bsoncore.Document(r).Lookup(key...))
}
// LookupErr searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
func (r Raw) LookupErr(key ...string) (RawValue, error) {
val, err := bsoncore.Document(r).LookupErr(key...)
return convertFromCoreValue(val), err
}
// Elements returns this document as a slice of elements. The returned slice will contain valid
// elements. If the document is not valid, the elements up to the invalid point will be returned
// along with an error.
func (r Raw) Elements() ([]RawElement, error) {
doc := bsoncore.Document(r)
if len(doc) == 0 {
return nil, nil
}
elems, err := doc.Elements()
if err != nil {
return nil, err
}
relems := make([]RawElement, 0, len(elems))
for _, elem := range elems {
relems = append(relems, RawElement(elem))
}
return relems, nil
}
// Values returns this document as a slice of values. The returned slice will contain valid values.
// If the document is not valid, the values up to the invalid point will be returned along with an
// error.
func (r Raw) Values() ([]RawValue, error) {
vals, err := bsoncore.Document(r).Values()
rvals := make([]RawValue, 0, len(vals))
for _, val := range vals {
rvals = append(rvals, convertFromCoreValue(val))
}
return rvals, err
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (r Raw) Index(index uint) RawElement { return RawElement(bsoncore.Document(r).Index(index)) }
// IndexErr searches for and retrieves the element at the given index.
func (r Raw) IndexErr(index uint) (RawElement, error) {
elem, err := bsoncore.Document(r).IndexErr(index)
return RawElement(elem), err
}
// String returns the BSON document encoded as Extended JSON.
func (r Raw) String() string { return bsoncore.Document(r).String() }
+73
View File
@@ -0,0 +1,73 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"io"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// RawArray is a raw bytes representation of a BSON array.
type RawArray []byte
// ReadArray reads a BSON array from the io.Reader and returns it as a
// bson.RawArray.
func ReadArray(r io.Reader) (RawArray, error) {
doc, err := bsoncore.NewArrayFromReader(r)
return RawArray(doc), err
}
// Index searches for and retrieves the value at the given index. This method
// will panic if the array is invalid or if the index is out of bounds.
func (a RawArray) Index(index uint) RawValue {
return convertFromCoreValue(bsoncore.Array(a).Index(index))
}
// IndexErr searches for and retrieves the value at the given index.
func (a RawArray) IndexErr(index uint) (RawValue, error) {
elem, err := bsoncore.Array(a).IndexErr(index)
return convertFromCoreValue(elem), err
}
// DebugString outputs a human readable version of Array. It will attempt to
// stringify the valid components of the array even if the entire array is not
// valid.
func (a RawArray) DebugString() string {
return bsoncore.Array(a).DebugString()
}
// String outputs an ExtendedJSON version of Array. If the Array is not valid,
// this method returns an empty string.
func (a RawArray) String() string {
return bsoncore.Array(a).String()
}
// Values returns this array as a slice of values. The returned slice will
// contain valid values. If the array is not valid, the values up to the invalid
// point will be returned along with an error.
func (a RawArray) Values() ([]RawValue, error) {
vals, err := bsoncore.Array(a).Values()
if err != nil {
return nil, err
}
rvals := make([]RawValue, 0, len(vals))
for _, val := range vals {
rvals = append(rvals, convertFromCoreValue(val))
}
return rvals, err
}
// Validate validates the array and ensures the elements contained within are
// valid.
func (a RawArray) Validate() error {
return bsoncore.Array(a).Validate()
}
+48
View File
@@ -0,0 +1,48 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// RawElement is a raw encoded BSON document or array element.
type RawElement []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If knowing if the element is valid is important, use KeyErr.
func (re RawElement) Key() string { return bsoncore.Element(re).Key() }
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (re RawElement) KeyErr() (string, error) { return bsoncore.Element(re).KeyErr() }
// Value returns the value of this element. If the element is not valid, this method returns an
// empty Value. If knowing if the element is valid is important, use ValueErr.
func (re RawElement) Value() RawValue { return convertFromCoreValue(bsoncore.Element(re).Value()) }
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (re RawElement) ValueErr() (RawValue, error) {
val, err := bsoncore.Element(re).ValueErr()
return convertFromCoreValue(val), err
}
// Validate ensures re is a valid BSON element.
func (re RawElement) Validate() error { return bsoncore.Element(re).Validate() }
// String returns the BSON element encoded as Extended JSON.
func (re RawElement) String() string {
doc := bsoncore.BuildDocument(nil, re)
j, err := MarshalExtJSON(Raw(doc), true, false)
if err != nil {
return "<malformed>"
}
return string(j)
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (re RawElement) DebugString() string { return bsoncore.Element(re).DebugString() }
+315
View File
@@ -0,0 +1,315 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// ErrNilContext is returned when the provided DecodeContext is nil.
var ErrNilContext = errors.New("DecodeContext cannot be nil")
// ErrNilRegistry is returned when the provided registry is nil.
var ErrNilRegistry = errors.New("Registry cannot be nil")
// RawValue is a raw encoded BSON value. It can be used to delay BSON value decoding or precompute
// BSON encoded value. Type is the BSON type of the value and Value is the raw encoded BSON value.
//
// A RawValue must be an individual BSON value. Use the Raw type for full BSON documents.
type RawValue struct {
Type Type
Value []byte
r *Registry
}
// IsZero reports whether the RawValue is zero, i.e. no data is present on
// the RawValue. It returns true if Type is 0 and Value is empty or nil.
func (rv RawValue) IsZero() bool {
return rv.Type == 0x00 && len(rv.Value) == 0
}
// Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an
// error is returned. This method will use the registry used to create the RawValue, if the RawValue
// was created from partial BSON processing, or it will use the default registry. Users wishing to
// specify the registry to use should use UnmarshalWithRegistry.
func (rv RawValue) Unmarshal(val any) error {
reg := rv.r
if reg == nil {
reg = defaultRegistry
}
return rv.UnmarshalWithRegistry(reg, val)
}
// Equal compares rv and rv2 and returns true if they are equal.
func (rv RawValue) Equal(rv2 RawValue) bool {
if rv.Type != rv2.Type {
return false
}
if !bytes.Equal(rv.Value, rv2.Value) {
return false
}
return true
}
// UnmarshalWithRegistry performs the same unmarshalling as Unmarshal but uses the provided registry
// instead of the one attached or the default registry.
func (rv RawValue) UnmarshalWithRegistry(r *Registry, val any) error {
if r == nil {
return ErrNilRegistry
}
vr := newBufferedValueReader(rv.Type, rv.Value)
rval := reflect.ValueOf(val)
if rval.Kind() != reflect.Ptr {
return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
}
rval = rval.Elem()
dec, err := r.LookupDecoder(rval.Type())
if err != nil {
return err
}
return dec.DecodeValue(DecodeContext{Registry: r}, vr, rval)
}
// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext
// instead of the one attached or the default registry.
func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val any) error {
if dc == nil {
return ErrNilContext
}
vr := newBufferedValueReader(rv.Type, rv.Value)
rval := reflect.ValueOf(val)
if rval.Kind() != reflect.Ptr {
return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
}
rval = rval.Elem()
dec, err := dc.LookupDecoder(rval.Type())
if err != nil {
return err
}
return dec.DecodeValue(*dc, vr, rval)
}
func convertFromCoreValue(v bsoncore.Value) RawValue {
return RawValue{Type: Type(v.Type), Value: v.Data}
}
func convertToCoreValue(v RawValue) bsoncore.Value {
return bsoncore.Value{Type: bsoncore.Type(v.Type), Data: v.Value}
}
// Validate ensures the value is a valid BSON value.
func (rv RawValue) Validate() error { return convertToCoreValue(rv).Validate() }
// IsNumber returns true if the type of v is a numeric BSON type.
func (rv RawValue) IsNumber() bool { return convertToCoreValue(rv).IsNumber() }
// String implements the fmt.String interface. This method will return values in extended JSON
// format. If the value is not valid, this returns an empty string
func (rv RawValue) String() string { return convertToCoreValue(rv).String() }
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (rv RawValue) DebugString() string { return convertToCoreValue(rv).DebugString() }
// Double returns the float64 value for this element.
// It panics if e's BSON type is not bson.TypeDouble.
func (rv RawValue) Double() float64 { return convertToCoreValue(rv).Double() }
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
func (rv RawValue) DoubleOK() (float64, bool) { return convertToCoreValue(rv).DoubleOK() }
// StringValue returns the string value for this element.
// It panics if e's BSON type is not bson.TypeString.
//
// NOTE: This method is called StringValue to avoid a collision with the String method which
// implements the fmt.Stringer interface.
func (rv RawValue) StringValue() string { return convertToCoreValue(rv).StringValue() }
// StringValueOK is the same as StringValue, but returns a boolean instead of
// panicking.
func (rv RawValue) StringValueOK() (string, bool) { return convertToCoreValue(rv).StringValueOK() }
// Document returns the BSON document the Value represents as a Document. It panics if the
// value is a BSON type other than document.
func (rv RawValue) Document() Raw { return Raw(convertToCoreValue(rv).Document()) }
// DocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (rv RawValue) DocumentOK() (Raw, bool) {
doc, ok := convertToCoreValue(rv).DocumentOK()
return Raw(doc), ok
}
// Array returns the BSON array the Value represents as an Array. It panics if the
// value is a BSON type other than array.
func (rv RawValue) Array() RawArray { return RawArray(convertToCoreValue(rv).Array()) }
// ArrayOK is the same as Array, except it returns a boolean instead
// of panicking.
func (rv RawValue) ArrayOK() (RawArray, bool) {
doc, ok := convertToCoreValue(rv).ArrayOK()
return RawArray(doc), ok
}
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
// other than binary.
func (rv RawValue) Binary() (subtype byte, data []byte) { return convertToCoreValue(rv).Binary() }
// BinaryOK is the same as Binary, except it returns a boolean instead of
// panicking.
func (rv RawValue) BinaryOK() (subtype byte, data []byte, ok bool) {
return convertToCoreValue(rv).BinaryOK()
}
// ObjectID returns the BSON objectid value the Value represents. It panics if the value is a BSON
// type other than objectid.
func (rv RawValue) ObjectID() ObjectID { return convertToCoreValue(rv).ObjectID() }
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
// panicking.
func (rv RawValue) ObjectIDOK() (ObjectID, bool) {
return convertToCoreValue(rv).ObjectIDOK()
}
// Boolean returns the boolean value the Value represents. It panics if the
// value is a BSON type other than boolean.
func (rv RawValue) Boolean() bool { return convertToCoreValue(rv).Boolean() }
// BooleanOK is the same as Boolean, except it returns a boolean instead of
// panicking.
func (rv RawValue) BooleanOK() (bool, bool) { return convertToCoreValue(rv).BooleanOK() }
// DateTime returns the BSON datetime value the Value represents as a
// unix timestamp. It panics if the value is a BSON type other than datetime.
func (rv RawValue) DateTime() int64 { return convertToCoreValue(rv).DateTime() }
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
// panicking.
func (rv RawValue) DateTimeOK() (int64, bool) { return convertToCoreValue(rv).DateTimeOK() }
// Time returns the BSON datetime value the Value represents. It panics if the value is a BSON
// type other than datetime.
func (rv RawValue) Time() time.Time { return convertToCoreValue(rv).Time() }
// TimeOK is the same as Time, except it returns a boolean instead of
// panicking.
func (rv RawValue) TimeOK() (time.Time, bool) { return convertToCoreValue(rv).TimeOK() }
// Regex returns the BSON regex value the Value represents. It panics if the value is a BSON
// type other than regex.
func (rv RawValue) Regex() (pattern, options string) { return convertToCoreValue(rv).Regex() }
// RegexOK is the same as Regex, except it returns a boolean instead of
// panicking.
func (rv RawValue) RegexOK() (pattern, options string, ok bool) {
return convertToCoreValue(rv).RegexOK()
}
// DBPointer returns the BSON dbpointer value the Value represents. It panics if the value is a BSON
// type other than DBPointer.
func (rv RawValue) DBPointer() (string, ObjectID) {
return convertToCoreValue(rv).DBPointer()
}
// DBPointerOK is the same as DBPoitner, except that it returns a boolean
// instead of panicking.
func (rv RawValue) DBPointerOK() (string, ObjectID, bool) {
return convertToCoreValue(rv).DBPointerOK()
}
// JavaScript returns the BSON JavaScript code value the Value represents. It panics if the value is
// a BSON type other than JavaScript code.
func (rv RawValue) JavaScript() string { return convertToCoreValue(rv).JavaScript() }
// JavaScriptOK is the same as Javascript, excepti that it returns a boolean
// instead of panicking.
func (rv RawValue) JavaScriptOK() (string, bool) { return convertToCoreValue(rv).JavaScriptOK() }
// Symbol returns the BSON symbol value the Value represents. It panics if the value is a BSON
// type other than symbol.
func (rv RawValue) Symbol() string { return convertToCoreValue(rv).Symbol() }
// SymbolOK is the same as Symbol, excepti that it returns a boolean
// instead of panicking.
func (rv RawValue) SymbolOK() (string, bool) { return convertToCoreValue(rv).SymbolOK() }
// CodeWithScope returns the BSON JavaScript code with scope the Value represents.
// It panics if the value is a BSON type other than JavaScript code with scope.
func (rv RawValue) CodeWithScope() (string, Raw) {
code, scope := convertToCoreValue(rv).CodeWithScope()
return code, Raw(scope)
}
// CodeWithScopeOK is the same as CodeWithScope, except that it returns a boolean instead of
// panicking.
func (rv RawValue) CodeWithScopeOK() (string, Raw, bool) {
code, scope, ok := convertToCoreValue(rv).CodeWithScopeOK()
return code, Raw(scope), ok
}
// Int32 returns the int32 the Value represents. It panics if the value is a BSON type other than
// int32.
func (rv RawValue) Int32() int32 { return convertToCoreValue(rv).Int32() }
// Int32OK is the same as Int32, except that it returns a boolean instead of
// panicking.
func (rv RawValue) Int32OK() (int32, bool) { return convertToCoreValue(rv).Int32OK() }
// Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a
// BSON type other than timestamp.
func (rv RawValue) Timestamp() (t, i uint32) { return convertToCoreValue(rv).Timestamp() }
// TimestampOK is the same as Timestamp, except that it returns a boolean
// instead of panicking.
func (rv RawValue) TimestampOK() (t, i uint32, ok bool) { return convertToCoreValue(rv).TimestampOK() }
// Int64 returns the int64 the Value represents. It panics if the value is a BSON type other than
// int64.
func (rv RawValue) Int64() int64 { return convertToCoreValue(rv).Int64() }
// Int64OK is the same as Int64, except that it returns a boolean instead of
// panicking.
func (rv RawValue) Int64OK() (int64, bool) { return convertToCoreValue(rv).Int64OK() }
// AsInt64 returns a BSON number as an int64. If the BSON type is not a numeric one, this method
// will panic.
func (rv RawValue) AsInt64() int64 { return convertToCoreValue(rv).AsInt64() }
// AsInt64OK is the same as AsInt64, except that it returns a boolean instead of
// panicking.
func (rv RawValue) AsInt64OK() (int64, bool) { return convertToCoreValue(rv).AsInt64OK() }
// AsFloat64 returns a BSON number as a float64. If the BSON type is not a numeric one, this method
// will panic.
func (rv RawValue) AsFloat64() float64 { return convertToCoreValue(rv).AsFloat64() }
// AsFloat64OK is the same as AsFloat64, except that it returns a boolean instead of
// panicking.
func (rv RawValue) AsFloat64OK() (float64, bool) { return convertToCoreValue(rv).AsFloat64OK() }
// Decimal128 returns the decimal the Value represents. It panics if the value is a BSON type other than
// decimal.
func (rv RawValue) Decimal128() Decimal128 { return NewDecimal128(convertToCoreValue(rv).Decimal128()) }
// Decimal128OK is the same as Decimal128, except that it returns a boolean
// instead of panicking.
func (rv RawValue) Decimal128OK() (Decimal128, bool) {
h, l, ok := convertToCoreValue(rv).Decimal128OK()
return NewDecimal128(h, l), ok
}
+49
View File
@@ -0,0 +1,49 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
// ArrayReader is implemented by types that allow reading values from a BSON
// array.
type ArrayReader interface {
ReadValue() (ValueReader, error)
}
// DocumentReader is implemented by types that allow reading elements from a
// BSON document.
type DocumentReader interface {
ReadElement() (string, ValueReader, error)
}
// ValueReader is a generic interface used to read values from BSON. This type
// is implemented by several types with different underlying representations of
// BSON, such as a bson.Document, raw BSON bytes, or extended JSON.
type ValueReader interface {
Type() Type
Skip() error
ReadArray() (ArrayReader, error)
ReadBinary() (b []byte, btype byte, err error)
ReadBoolean() (bool, error)
ReadDocument() (DocumentReader, error)
ReadCodeWithScope() (code string, dr DocumentReader, err error)
ReadDBPointer() (ns string, oid ObjectID, err error)
ReadDateTime() (int64, error)
ReadDecimal128() (Decimal128, error)
ReadDouble() (float64, error)
ReadInt32() (int32, error)
ReadInt64() (int64, error)
ReadJavascript() (code string, err error)
ReadMaxKey() error
ReadMinKey() error
ReadNull() error
ReadObjectID() (ObjectID, error)
ReadRegex() (pattern, options string, err error)
ReadString() (string, error)
ReadSymbol() (symbol string, err error)
ReadTimestamp() (t, i uint32, err error)
ReadUndefined() error
}
+381
View File
@@ -0,0 +1,381 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sync"
)
// defaultRegistry is the default Registry. It contains the default codecs and the
// primitive codecs.
var defaultRegistry = NewRegistry()
// errNoEncoder is returned when there wasn't an encoder available for a type.
type errNoEncoder struct {
Type reflect.Type
}
func (ene errNoEncoder) Error() string {
if ene.Type == nil {
return "no encoder found for <nil>"
}
return "no encoder found for " + ene.Type.String()
}
// errNoDecoder is returned when there wasn't a decoder available for a type.
type errNoDecoder struct {
Type reflect.Type
}
func (end errNoDecoder) Error() string {
return "no decoder found for " + end.Type.String()
}
// errNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type.
type errNoTypeMapEntry struct {
Type Type
}
func (entme errNoTypeMapEntry) Error() string {
return "no type map entry found for " + entme.Type.String()
}
// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type
// documentation for examples of registering various custom encoders and decoders. A Registry can
// have four main types of codecs:
//
// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and
// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value
// whose type matches the registered type exactly.
// If the registered type is an interface, the codec will be invoked when encoding or decoding
// values whose type is the interface, but not for values with concrete types that implement the
// interface.
//
// 2. Interface encoders/decoders - These can be registered using the RegisterInterfaceEncoder and
// RegisterInterfaceDecoder methods. These methods only accept interface types and the registered codecs
// will be invoked when encoding or decoding values whose types implement the interface. An example
// of an interface defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method
// for any value whose type implements bson.Marshaler, regardless of the value's concrete type.
//
// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type
// associations are used when decoding into a bson.D/bson.M or a struct field of type any.
// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances,
// respectively, when decoding into a bson.D. The following code would change the behavior so these
// values decode as Go int instances instead:
//
// intType := reflect.TypeOf(int(0))
// registry.RegisterTypeMapEntry(bson.TypeInt32, intType).RegisterTypeMapEntry(bson.TypeInt64, intType)
//
// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and
// RegisterDefaultDecoder methods. The registered codec will be invoked when encoding or decoding
// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't
// match a registered type or interface encoder/decoder first. These methods should be used to change the
// behavior for all values for a specific kind.
//
// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure.
type Registry struct {
interfaceEncoders []interfaceValueEncoder
interfaceDecoders []interfaceValueDecoder
typeEncoders *typeEncoderCache
typeDecoders *typeDecoderCache
kindEncoders *kindEncoderCache
kindDecoders *kindDecoderCache
typeMap sync.Map // map[Type]reflect.Type
}
// NewRegistry creates a new empty Registry.
func NewRegistry() *Registry {
reg := &Registry{
typeEncoders: new(typeEncoderCache),
typeDecoders: new(typeDecoderCache),
kindEncoders: new(kindEncoderCache),
kindDecoders: new(kindDecoderCache),
}
registerDefaultEncoders(reg)
registerDefaultDecoders(reg)
registerPrimitiveCodecs(reg)
return reg
}
// RegisterTypeEncoder registers the provided ValueEncoder for the provided type.
//
// The type will be used as provided, so an encoder can be registered for a type and a different
// encoder can be registered for a pointer to that type.
//
// If the given type is an interface, the encoder will be called when marshaling a type that is
// that interface. It will not be called when marshaling a non-interface type that implements the
// interface. To get the latter behavior, call RegisterHookEncoder instead.
//
// RegisterTypeEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) {
r.typeEncoders.Store(valueType, enc)
}
// RegisterTypeDecoder registers the provided ValueDecoder for the provided type.
//
// The type will be used as provided, so a decoder can be registered for a type and a different
// decoder can be registered for a pointer to that type.
//
// If the given type is an interface, the decoder will be called when unmarshaling into a type that
// is that interface. It will not be called when unmarshaling into a non-interface type that
// implements the interface. To get the latter behavior, call RegisterHookDecoder instead.
//
// RegisterTypeDecoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) {
r.typeDecoders.Store(valueType, dec)
}
// RegisterKindEncoder registers the provided ValueEncoder for the provided kind.
//
// Use RegisterKindEncoder to register an encoder for any type with the same underlying kind. For
// example, consider the type MyInt defined as
//
// type MyInt int32
//
// To define an encoder for MyInt and int32, use RegisterKindEncoder like
//
// reg.RegisterKindEncoder(reflect.Int32, myEncoder)
//
// RegisterKindEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) {
r.kindEncoders.Store(kind, enc)
}
// RegisterKindDecoder registers the provided ValueDecoder for the provided kind.
//
// Use RegisterKindDecoder to register a decoder for any type with the same underlying kind. For
// example, consider the type MyInt defined as
//
// type MyInt int32
//
// To define an decoder for MyInt and int32, use RegisterKindDecoder like
//
// reg.RegisterKindDecoder(reflect.Int32, myDecoder)
//
// RegisterKindDecoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) {
r.kindDecoders.Store(kind, dec)
}
// RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will
// be called when marshaling a type if the type implements iface or a pointer to the type
// implements iface. If the provided type is not an interface
// (i.e. iface.Kind() != reflect.Interface), this method will panic.
//
// RegisterInterfaceEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) {
if iface.Kind() != reflect.Interface {
panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+
"got type %s with kind %s", iface, iface.Kind())
panic(panicStr)
}
for idx, encoder := range r.interfaceEncoders {
if encoder.i == iface {
r.interfaceEncoders[idx].ve = enc
return
}
}
r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{i: iface, ve: enc})
}
// RegisterInterfaceDecoder registers an decoder for the provided interface type iface. This decoder will
// be called when unmarshaling into a type if the type implements iface or a pointer to the type
// implements iface. If the provided type is not an interface (i.e. iface.Kind() != reflect.Interface),
// this method will panic.
//
// RegisterInterfaceDecoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) {
if iface.Kind() != reflect.Interface {
panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+
"got type %s with kind %s", iface, iface.Kind())
panic(panicStr)
}
for idx, decoder := range r.interfaceDecoders {
if decoder.i == iface {
r.interfaceDecoders[idx].vd = dec
return
}
}
r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec})
}
// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this
// mapping is decoding situations where an empty interface is used and a default type needs to be
// created and decoded into.
//
// By default, BSON documents will decode into any values as bson.D. To change the default type for BSON
// documents, a type map entry for TypeEmbeddedDocument should be registered. For example, to force BSON documents
// to decode to bson.Raw, use the following code:
//
// reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{}))
func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) {
r.typeMap.Store(bt, rt)
}
// LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup
// order:
//
// 1. An encoder registered for the exact type. If the given type is an interface, an encoder
// registered using RegisterTypeEncoder for that interface will be selected.
//
// 2. An encoder registered using RegisterInterfaceEncoder for an interface implemented by the type
// or by a pointer to the type. If the value matches multiple interfaces (e.g. the type implements
// bson.Marshaler and bson.ValueMarshaler), the first one registered will be selected.
// Note that registries constructed using bson.NewRegistry have driver-defined interfaces registered
// for the bson.Marshaler, bson.ValueMarshaler, and bson.Proxy interfaces, so those will take
// precedence over any new interfaces.
//
// 3. An encoder registered using RegisterKindEncoder for the kind of value.
//
// If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for
// concurrent use by multiple goroutines after all codecs and encoders are registered.
func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
if valueType == nil {
return nil, errNoEncoder{Type: valueType}
}
enc, found := r.lookupTypeEncoder(valueType)
if found {
if enc == nil {
return nil, errNoEncoder{Type: valueType}
}
return enc, nil
}
enc, found = r.lookupInterfaceEncoder(valueType, true)
if found {
return r.typeEncoders.LoadOrStore(valueType, enc), nil
}
if v, ok := r.kindEncoders.Load(valueType.Kind()); ok {
return r.storeTypeEncoder(valueType, v), nil
}
return nil, errNoEncoder{Type: valueType}
}
func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder {
return r.typeEncoders.LoadOrStore(rt, enc)
}
func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) {
return r.typeEncoders.Load(rt)
}
func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) {
if valueType == nil {
return nil, false
}
for _, ienc := range r.interfaceEncoders {
if valueType.Implements(ienc.i) {
return ienc.ve, true
}
if allowAddr && valueType.Kind() != reflect.Ptr && reflect.PtrTo(valueType).Implements(ienc.i) {
// if *t implements an interface, this will catch if t implements an interface further
// ahead in interfaceEncoders
defaultEnc, found := r.lookupInterfaceEncoder(valueType, false)
if !found {
defaultEnc, _ = r.kindEncoders.Load(valueType.Kind())
}
return newCondAddrEncoder(ienc.ve, defaultEnc), true
}
}
return nil, false
}
// LookupDecoder returns the first matching decoder in the Registry. It uses the following lookup
// order:
//
// 1. A decoder registered for the exact type. If the given type is an interface, a decoder
// registered using RegisterTypeDecoder for that interface will be selected.
//
// 2. A decoder registered using RegisterInterfaceDecoder for an interface implemented by the type or by
// a pointer to the type. If the value matches multiple interfaces (e.g. the type implements
// bson.Unmarshaler and bson.ValueUnmarshaler), the first one registered will be selected.
// Note that registries constructed using bson.NewRegistry have driver-defined interfaces registered
// for the bson.Unmarshaler and bson.ValueUnmarshaler interfaces, so those will take
// precedence over any new interfaces.
//
// 3. A decoder registered using RegisterKindDecoder for the kind of value.
//
// If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for
// concurrent use by multiple goroutines after all codecs and decoders are registered.
func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) {
if valueType == nil {
return nil, errors.New("cannot perform a decoder lookup on <nil>")
}
dec, found := r.lookupTypeDecoder(valueType)
if found {
if dec == nil {
return nil, errNoDecoder{Type: valueType}
}
return dec, nil
}
dec, found = r.lookupInterfaceDecoder(valueType, true)
if found {
return r.storeTypeDecoder(valueType, dec), nil
}
if v, ok := r.kindDecoders.Load(valueType.Kind()); ok {
return r.storeTypeDecoder(valueType, v), nil
}
return nil, errNoDecoder{Type: valueType}
}
func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) {
return r.typeDecoders.Load(valueType)
}
func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder {
return r.typeDecoders.LoadOrStore(typ, dec)
}
func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) {
for _, idec := range r.interfaceDecoders {
if valueType.Implements(idec.i) {
return idec.vd, true
}
if allowAddr && valueType.Kind() != reflect.Ptr && reflect.PtrTo(valueType).Implements(idec.i) {
// if *t implements an interface, this will catch if t implements an interface further
// ahead in interfaceDecoders
defaultDec, found := r.lookupInterfaceDecoder(valueType, false)
if !found {
defaultDec, _ = r.kindDecoders.Load(valueType.Kind())
}
return newCondAddrDecoder(idec.vd, defaultDec), true
}
}
return nil, false
}
// LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON
// type. If no type is found, ErrNoTypeMapEntry is returned.
//
// LookupTypeMapEntry should not be called concurrently with any other Registry method.
func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) {
v, ok := r.typeMap.Load(bt)
if v == nil || !ok {
return nil, errNoTypeMapEntry{Type: bt}
}
return v.(reflect.Type), nil
}
type interfaceValueEncoder struct {
i reflect.Type
ve ValueEncoder
}
type interfaceValueDecoder struct {
i reflect.Type
vd ValueDecoder
}
+173
View File
@@ -0,0 +1,173 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
)
// sliceCodec is the Codec used for slice values.
type sliceCodec struct {
// encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of
// BSON null.
encodeNilAsEmpty bool
}
// EncodeValue is the ValueEncoder for slice types.
func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Slice {
return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
if val.IsNil() && !sc.encodeNilAsEmpty && !ec.nilSliceAsEmpty {
return vw.WriteNull()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
byteSlice := make([]byte, val.Len())
reflect.Copy(reflect.ValueOf(byteSlice), val)
return vw.WriteBinary(byteSlice)
}
// If we have a []E we want to treat it as a document instead of as an array.
if val.Type() == tD || val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(D)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for _, e := range d {
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// DecodeValue is the ValueDecoder for slice types.
func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Slice {
return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
switch vrType := vr.Type(); vrType {
case TypeArray:
case TypeNull:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case TypeUndefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
case Type(0), TypeEmbeddedDocument:
if val.Type().Elem() != tE {
return fmt.Errorf("cannot decode document into %s", val.Type())
}
case TypeBinary:
if val.Type().Elem() != tByte {
return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType)
}
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
return fmt.Errorf("SliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", TypeBinary, subtype)
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(data)))
}
val.SetLen(0)
val.Set(reflect.AppendSlice(val, reflect.ValueOf(data)))
return nil
case TypeString:
if sliceType := val.Type().Elem(); sliceType != tByte {
return fmt.Errorf("SliceDecodeValue can only decode a string into a byte array, got %v", sliceType)
}
str, err := vr.ReadString()
if err != nil {
return err
}
byteStr := []byte(str)
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(byteStr)))
}
val.SetLen(0)
val.Set(reflect.AppendSlice(val, reflect.ValueOf(byteStr)))
return nil
default:
return fmt.Errorf("cannot decode %v into a slice", vrType)
}
var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error)
switch val.Type().Elem() {
case tE:
elemsFunc = decodeD
default:
elemsFunc = decodeDefault
}
elems, err := elemsFunc(dc, vr, val)
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(elems)))
}
val.SetLen(0)
val.Set(reflect.Append(val, elems...))
return nil
}
+104
View File
@@ -0,0 +1,104 @@
// Copyright (C) MongoDB, Inc. 2025-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bufio"
"io"
)
// streamingByteSrc reads from an ioReader wrapped in a bufio.Reader. It
// first reads the BSON length header, then ensures it only ever reads exactly
// that many bytes.
//
// Note: this approach trades memory usage for extra buffering and reader calls,
// so it is less performanted than the in-memory bufferedValueReader.
type streamingByteSrc struct {
br *bufio.Reader
offset int64 // offset is the current read position in the buffer
}
var _ byteSrc = (*streamingByteSrc)(nil)
// Read reads up to len(p) bytes from the underlying bufio.Reader, advancing
// the offset by the number of bytes read.
func (s *streamingByteSrc) readExact(p []byte) (int, error) {
n, err := io.ReadFull(s.br, p)
if err == nil {
s.offset += int64(n)
}
return n, err
}
// ReadByte returns the single byte at buf[offset] and advances offset by 1.
func (s *streamingByteSrc) ReadByte() (byte, error) {
c, err := s.br.ReadByte()
if err == nil {
s.offset++
}
return c, err
}
// peek returns buf[offset:offset+n] without advancing offset.
func (s *streamingByteSrc) peek(n int) ([]byte, error) {
return s.br.Peek(n)
}
// discard advances offset by n bytes, returning the number of bytes discarded.
func (s *streamingByteSrc) discard(n int) (int, error) {
m, err := s.br.Discard(n)
s.offset += int64(m)
return m, err
}
// readSlice scans buf[offset:] for the first occurrence of delim, returns
// buf[offset:idx+1], and advances offset past it; errors if delim not found.
func (s *streamingByteSrc) readSlice(delim byte) ([]byte, error) {
data, err := s.br.ReadSlice(delim)
if err != nil {
return nil, err
}
s.offset += int64(len(data))
return data, nil
}
// pos returns the current read position in the buffer.
func (s *streamingByteSrc) pos() int64 {
return s.offset
}
// regexLength will return the total byte length of a BSON regex value.
func (s *streamingByteSrc) regexLength() (int32, error) {
var (
count int32
nulCount int
)
for nulCount < 2 {
buf, err := s.br.Peek(int(count) + 1)
if err != nil {
return 0, err
}
b := buf[count]
count++
if b == 0x00 {
nulCount++
}
}
return count, nil
}
func (*streamingByteSrc) streamable() bool {
return true
}
func (s *streamingByteSrc) reset() {
s.offset = 0
}
+107
View File
@@ -0,0 +1,107 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
)
// stringCodec is the Codec used for string values.
type stringCodec struct{}
// Assert that stringCodec satisfies the typeDecoder interface, which allows it to be
// used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &stringCodec{}
// EncodeValue is the ValueEncoder for string types.
func (sc *stringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{
Name: "StringEncodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: val,
}
}
return vw.WriteString(val.String())
}
func (sc *stringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t.Kind() != reflect.String {
return emptyValue, ValueDecoderError{
Name: "StringDecodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: reflect.Zero(t),
}
}
var str string
var err error
switch vr.Type() {
case TypeString:
str, err = vr.ReadString()
if err != nil {
return emptyValue, err
}
case TypeObjectID:
if dc.objectIDAsHexString {
oid, err := vr.ReadObjectID()
if err != nil {
return emptyValue, err
}
str = oid.Hex()
} else {
const msg = "decoding an object ID into a string is not supported by default " +
"(set Decoder.ObjectIDAsHexString to enable decoding as a hexadecimal string)"
return emptyValue, errors.New(msg)
}
case TypeSymbol:
str, err = vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
case TypeBinary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "string"}
}
str = string(data)
case TypeNull:
if err = vr.ReadNull(); err != nil {
return emptyValue, err
}
case TypeUndefined:
if err = vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into a string type", vr.Type())
}
return reflect.ValueOf(str), nil
}
// DecodeValue is the ValueDecoder for string types.
func (sc *stringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.String {
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
elem, err := sc.decodeType(dctx, vr, val.Type())
if err != nil {
return err
}
val.SetString(elem.String())
return nil
}
+695
View File
@@ -0,0 +1,695 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
)
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
type DecodeError struct {
keys []string
wrapped error
}
// Unwrap returns the underlying error
func (de *DecodeError) Unwrap() error {
return de.wrapped
}
// Error implements the error interface.
func (de *DecodeError) Error() string {
// The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the
// stack of BSON keys, so we call de.Keys(), which reverses them.
keyPath := strings.Join(de.Keys(), ".")
return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped)
}
// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"].
func (de *DecodeError) Keys() []string {
reversedKeys := make([]string, 0, len(de.keys))
for idx := len(de.keys) - 1; idx >= 0; idx-- {
reversedKeys = append(reversedKeys, de.keys[idx])
}
return reversedKeys
}
// mapElementsEncoder handles encoding of the values of an inline map.
type mapElementsEncoder interface {
encodeMapElements(EncodeContext, DocumentWriter, reflect.Value, func(string) bool) error
}
// structCodec is the Codec used for struct values.
type structCodec struct {
cache sync.Map // map[reflect.Type]*structDescription
inlineMapEncoder mapElementsEncoder
// decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the
// destination value passed to Decode before unmarshaling BSON documents into them.
decodeZeroStruct bool
// decodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the
// destination value passed to Decode before unmarshaling BSON documents into them.
decodeDeepZeroInline bool
// encodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g.
// MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag
// option is set.
encodeOmitDefaultStruct bool
// allowUnexportedFields allows encoding and decoding values from un-exported struct fields.
allowUnexportedFields bool
// overwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is
// a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The
// default value is true.
overwriteDuplicatedInlinedFields bool
}
var (
_ ValueEncoder = &structCodec{}
_ ValueDecoder = &structCodec{}
)
// newStructCodec returns a StructCodec that uses p for struct tag parsing.
func newStructCodec(elemEncoder mapElementsEncoder) *structCodec {
return &structCodec{
inlineMapEncoder: elemEncoder,
overwriteDuplicatedInlinedFields: true,
}
}
// EncodeValue handles encoding generic struct types.
func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Struct {
return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates)
if err != nil {
return err
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
var rv reflect.Value
for _, desc := range sd.fl {
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
rv, err = fieldByIndexErr(val, desc.inline)
if err != nil {
continue
}
}
if ec.omitEmpty {
desc.omitEmpty = true
}
desc.encoder, rv, err = lookupElementEncoder(ec, desc.encoder, rv)
if err != nil && !errors.Is(err, errInvalidValue) {
return err
}
if errors.Is(err, errInvalidValue) {
if desc.omitEmpty {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
err = vw2.WriteNull()
if err != nil {
return err
}
continue
}
if desc.encoder == nil {
return errNoEncoder{Type: rv.Type()}
}
encoder := desc.encoder
var empty bool
if rv.Kind() == reflect.Interface {
// isEmpty will not treat an interface rv as an interface, so we need to check for the
// nil interface separately.
empty = rv.IsNil()
} else {
empty = isEmpty(rv, sc.encodeOmitDefaultStruct || ec.omitZeroStruct)
}
if desc.omitEmpty && empty {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
ectx := EncodeContext{
Registry: ec.Registry,
minSize: desc.minSize || ec.minSize,
errorOnInlineDuplicates: ec.errorOnInlineDuplicates,
stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt,
nilMapAsEmpty: ec.nilMapAsEmpty,
nilSliceAsEmpty: ec.nilSliceAsEmpty,
nilByteSliceAsEmpty: ec.nilByteSliceAsEmpty,
omitZeroStruct: ec.omitZeroStruct,
useJSONStructTags: ec.useJSONStructTags,
}
err = encoder.EncodeValue(ectx, vw2, rv)
if err != nil {
return err
}
}
if sd.inlineMap >= 0 {
rv := val.Field(sd.inlineMap)
collisionFn := func(key string) bool {
_, exists := sd.fm[key]
return exists
}
err = sc.inlineMapEncoder.encodeMapElements(ec, dw, rv, collisionFn)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
func newDecodeError(key string, original error) error {
var de *DecodeError
if !errors.As(original, &de) {
return &DecodeError{
keys: []string{key},
wrapped: original,
}
}
de.keys = append(de.keys, key)
return de
}
// DecodeValue implements the Codec interface.
// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Struct {
return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
switch vrType := vr.Type(); vrType {
case Type(0), TypeEmbeddedDocument:
case TypeNull:
if err := vr.ReadNull(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
case TypeUndefined:
if err := vr.ReadUndefined(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false)
if err != nil {
return err
}
if sc.decodeZeroStruct || dc.zeroStructs {
val.Set(reflect.Zero(val.Type()))
}
if sc.decodeDeepZeroInline && sd.inline {
val.Set(deepZero(val.Type()))
}
var decoder ValueDecoder
var inlineMap reflect.Value
if sd.inlineMap >= 0 {
inlineMap = val.Field(sd.inlineMap)
decoder, err = dc.LookupDecoder(inlineMap.Type().Elem())
if err != nil {
return err
}
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
for {
name, vr, err := dr.ReadElement()
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
return err
}
fd, exists := sd.fm[name]
if !exists {
// if the original name isn't found in the struct description, try again with the name in lowercase
// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
// names
fd, exists = sd.fm[strings.ToLower(name)]
}
if !exists {
if sd.inlineMap < 0 {
// The encoding/json package requires a flag to return on error for non-existent fields.
// This functionality seems appropriate for the struct codec.
err = vr.Skip()
if err != nil {
return err
}
continue
}
if inlineMap.IsNil() {
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
}
elem := reflect.New(inlineMap.Type().Elem()).Elem()
err = decoder.DecodeValue(dc, vr, elem)
if err != nil {
return err
}
inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
continue
}
var field reflect.Value
if fd.inline == nil {
field = val.Field(fd.idx)
} else {
field, err = getInlineField(val, fd.inline)
if err != nil {
return err
}
}
if field.Kind() == reflect.Interface && !field.IsNil() && field.Elem().Kind() == reflect.Ptr {
v := field.Elem().Elem()
decoder, err = dc.LookupDecoder(v.Type())
if err != nil {
return err
}
err = decoder.DecodeValue(dc, vr, v)
if err != nil {
return newDecodeError(fd.name, err)
}
continue
}
if !field.CanSet() { // Being settable is a super set of being addressable.
innerErr := fmt.Errorf("field %v is not settable", field)
return newDecodeError(fd.name, innerErr)
}
if field.Kind() == reflect.Ptr && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Addr()
dctx := DecodeContext{
Registry: dc.Registry,
truncate: fd.truncate || dc.truncate,
defaultDocumentType: dc.defaultDocumentType,
binaryAsSlice: dc.binaryAsSlice,
objectIDAsHexString: dc.objectIDAsHexString,
useJSONStructTags: dc.useJSONStructTags,
useLocalTimeZone: dc.useLocalTimeZone,
zeroMaps: dc.zeroMaps,
zeroStructs: dc.zeroStructs,
}
if fd.decoder == nil {
return newDecodeError(fd.name, errNoDecoder{Type: field.Elem().Type()})
}
err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
if err != nil {
return newDecodeError(fd.name, err)
}
}
return nil
}
func isEmpty(v reflect.Value, omitZeroStruct bool) bool {
kind := v.Kind()
if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) {
return v.Interface().(Zeroer).IsZero()
}
switch kind {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Struct:
if !omitZeroStruct {
return false
}
vt := v.Type()
if vt == tTime {
return v.Interface().(time.Time).IsZero()
}
numField := vt.NumField()
for i := 0; i < numField; i++ {
ff := vt.Field(i)
if ff.PkgPath != "" && !ff.Anonymous {
continue // Private field
}
if !isEmpty(v.Field(i), omitZeroStruct) {
return false
}
}
return true
}
return !v.IsValid() || v.IsZero()
}
type structDescription struct {
fm map[string]fieldDescription
fl []fieldDescription
inlineMap int
inline bool
}
type fieldDescription struct {
name string // BSON key name
fieldName string // struct field name
idx int
omitEmpty bool
minSize bool
truncate bool
inline []int
encoder ValueEncoder
decoder ValueDecoder
}
type byIndex []fieldDescription
func (bi byIndex) Len() int { return len(bi) }
func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
func (bi byIndex) Less(i, j int) bool {
// If a field is inlined, its index in the top level struct is stored at inline[0]
iIdx, jIdx := bi[i].idx, bi[j].idx
if len(bi[i].inline) > 0 {
iIdx = bi[i].inline[0]
}
if len(bi[j].inline) > 0 {
jIdx = bi[j].inline[0]
}
if iIdx != jIdx {
return iIdx < jIdx
}
for k, biik := range bi[i].inline {
if k >= len(bi[j].inline) {
return false
}
if biik != bi[j].inline[k] {
return biik < bi[j].inline[k]
}
}
return len(bi[i].inline) < len(bi[j].inline)
}
func (sc *structCodec) describeStruct(
r *Registry,
t reflect.Type,
useJSONStructTags bool,
errorOnDuplicates bool,
) (*structDescription, error) {
// We need to analyze the struct, including getting the tags, collecting
// information about inlining, and create a map of the field name to the field.
if v, ok := sc.cache.Load(t); ok {
return v.(*structDescription), nil
}
// TODO(charlie): Only describe the struct once when called
// concurrently with the same type.
ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates)
if err != nil {
return nil, err
}
if v, loaded := sc.cache.LoadOrStore(t, ds); loaded {
ds = v.(*structDescription)
}
return ds, nil
}
func (sc *structCodec) describeStructSlow(
r *Registry,
t reflect.Type,
useJSONStructTags bool,
errorOnDuplicates bool,
) (*structDescription, error) {
numFields := t.NumField()
sd := &structDescription{
fm: make(map[string]fieldDescription, numFields),
fl: make([]fieldDescription, 0, numFields),
inlineMap: -1,
}
var fields []fieldDescription
for i := 0; i < numFields; i++ {
sf := t.Field(i)
if sf.PkgPath != "" && (!sc.allowUnexportedFields || !sf.Anonymous) {
// field is private or unexported fields aren't allowed, ignore
continue
}
sfType := sf.Type
encoder, err := r.LookupEncoder(sfType)
if err != nil {
encoder = nil
}
decoder, err := r.LookupDecoder(sfType)
if err != nil {
decoder = nil
}
description := fieldDescription{
fieldName: sf.Name,
idx: i,
encoder: encoder,
decoder: decoder,
}
var stags *structTags
// If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser
// instead of the parser defined on the codec.
if useJSONStructTags {
stags, err = parseJSONStructTags(sf)
} else {
stags, err = parseStructTags(sf)
}
if err != nil {
return nil, err
}
if stags.Skip {
continue
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
description.minSize = stags.MinSize
description.truncate = stags.Truncate
if stags.Inline {
sd.inline = true
switch sfType.Kind() {
case reflect.Map:
if sd.inlineMap >= 0 {
return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
}
if sfType.Key() != tString {
return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
}
sd.inlineMap = description.idx
case reflect.Ptr:
sfType = sfType.Elem()
if sfType.Kind() != reflect.Struct {
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
fallthrough
case reflect.Struct:
inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates)
if err != nil {
return nil, err
}
for _, fd := range inlinesf.fl {
if fd.inline == nil {
fd.inline = []int{i, fd.idx}
} else {
fd.inline = append([]int{i}, fd.inline...)
}
fields = append(fields, fd)
}
default:
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
continue
}
fields = append(fields, description)
}
// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
sort.Slice(fields, func(i, j int) bool {
x := fields
// sort field by name, breaking ties with depth, then
// breaking ties with index sequence.
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].inline) != len(x[j].inline) {
return len(x[i].inline) < len(x[j].inline)
}
return byIndex(x).Less(i, j)
})
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
sd.fl = append(sd.fl, fi)
sd.fm[name] = fi
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if !ok || !sc.overwriteDuplicatedInlinedFields || errorOnDuplicates {
return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
}
sd.fl = append(sd.fl, dominant)
sd.fm[name] = dominant
}
sort.Sort(byIndex(sd.fl))
return sd, nil
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's inlining rules. If there are multiple top-level
// fields, the boolean will be false: This condition is an error in Go
// and we skip all the fields.
func dominantField(fields []fieldDescription) (fieldDescription, bool) {
// The fields are sorted in increasing index-length order, then by presence of tag.
// That means that the first field is the dominant one. We need only check
// for error cases: two fields at top level.
if len(fields) > 1 &&
len(fields[0].inline) == len(fields[1].inline) {
return fieldDescription{}, false
}
return fields[0], true
}
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
defer func() {
if recovered := recover(); recovered != nil {
switch r := recovered.(type) {
case string:
err = fmt.Errorf("%s", r)
case error:
err = r
}
}
}()
result = v.FieldByIndex(index)
return
}
func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
field, err := fieldByIndexErr(val, index)
if err == nil {
return field, nil
}
// if parent of this element doesn't exist, fix its parent
inlineParent := index[:len(index)-1]
var fParent reflect.Value
if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
fParent, err = getInlineField(val, inlineParent)
if err != nil {
return fParent, err
}
}
fParent.Set(reflect.New(fParent.Type().Elem()))
return fieldByIndexErr(val, index)
}
// DeepZero returns recursive zero object
func deepZero(st reflect.Type) (result reflect.Value) {
if st.Kind() == reflect.Struct {
numField := st.NumField()
for i := 0; i < numField; i++ {
if result == emptyValue {
result = reflect.Indirect(reflect.New(st))
}
f := result.Field(i)
if f.CanInterface() {
if f.Type().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem())))
}
}
}
}
return result
}
// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
func recursivePointerTo(v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
result := reflect.New(v.Type())
if v.Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
if f := v.Field(i); f.Kind() == reflect.Ptr {
if f.Elem().Kind() == reflect.Struct {
result.Elem().Field(i).Set(recursivePointerTo(f))
}
}
}
}
return result
}
+123
View File
@@ -0,0 +1,123 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"strings"
)
// structTags represents the struct tag fields that the StructCodec uses during
// the encoding and decoding process.
//
// In the case of a struct, the lowercased field name is used as the key for each exported
// field but this behavior may be changed using a struct tag. The tag may also contain flags to
// adjust the marshalling behavior for the field.
//
// The properties are defined below:
//
// OmitEmpty Only include the field if it's not set to the zero value for the type or to
// empty slices or maps.
//
// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's
// feasible while preserving the numeric value.
//
// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within
// a float32.
//
// Inline Inline the field, which must be a struct or a map, causing all of its fields
// or keys to be processed as if they were part of the outer struct. For maps,
// keys must not conflict with the bson keys of other struct fields.
//
// Skip This struct field should be skipped. This is usually denoted by parsing a "-"
// for the name.
type structTags struct {
Name string
OmitEmpty bool
MinSize bool
Truncate bool
Inline bool
Skip bool
}
// DefaultStructTagParser is the StructTagParser used by the StructCodec by default.
// It will handle the bson struct tag. See the documentation for StructTags to see
// what each of the returned fields means.
//
// If there is no name in the struct tag fields, the struct field name is lowercased.
// The tag formats accepted are:
//
// "[<key>][,<flag1>[,<flag2>]]"
//
// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
//
// An example:
//
// type T struct {
// A bool
// B int "myb"
// C string "myc,omitempty"
// D string `bson:",omitempty" json:"jsonkey"`
// E int64 ",minsize"
// F int64 "myf,omitempty,minsize"
// }
//
// A struct tag either consisting entirely of '-' or with a bson key with a
// value consisting entirely of '-' will return a StructTags with Skip true and
// the remaining fields will be their default values.
func parseStructTags(sf reflect.StructField) (*structTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
return parseTags(key, tag)
}
// jsonStructTagParser has the same behavior as DefaultStructTagParser
// but will also fallback to parsing the json tag instead on a field where the
// bson tag isn't available.
func parseJSONStructTags(sf reflect.StructField) (*structTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok {
tag, ok = sf.Tag.Lookup("json")
}
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
return parseTags(key, tag)
}
func parseTags(key string, tag string) (*structTags, error) {
var st structTags
if tag == "-" {
st.Skip = true
return &st, nil
}
for idx, str := range strings.Split(tag, ",") {
if idx == 0 && str != "" {
key = str
}
switch str {
case "omitempty":
st.OmitEmpty = true
case "minsize":
st.MinSize = true
case "truncate":
st.Truncate = true
case "inline":
st.Inline = true
}
}
st.Name = key
return &st, nil
}
+109
View File
@@ -0,0 +1,109 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
"time"
)
const (
timeFormatString = "2006-01-02T15:04:05.999Z07:00"
)
// timeCodec is the Codec used for time.Time values.
type timeCodec struct {
// useLocalTimeZone specifies if we should decode into the local time zone. Defaults to false.
useLocalTimeZone bool
}
// Assert that timeCodec satisfies the typeDecoder interface, which allows it to be used
// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection.
var _ typeDecoder = &timeCodec{}
func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tTime {
return emptyValue, ValueDecoderError{
Name: "TimeDecodeValue",
Types: []reflect.Type{tTime},
Received: reflect.Zero(t),
}
}
var timeVal time.Time
switch vrType := vr.Type(); vrType {
case TypeDateTime:
dt, err := vr.ReadDateTime()
if err != nil {
return emptyValue, err
}
timeVal = time.Unix(dt/1000, dt%1000*1000000)
case TypeString:
// assume strings are in the isoTimeFormat
timeStr, err := vr.ReadString()
if err != nil {
return emptyValue, err
}
timeVal, err = time.Parse(timeFormatString, timeStr)
if err != nil {
return emptyValue, err
}
case TypeInt64:
i64, err := vr.ReadInt64()
if err != nil {
return emptyValue, err
}
timeVal = time.Unix(i64/1000, i64%1000*1000000)
case TypeTimestamp:
t, _, err := vr.ReadTimestamp()
if err != nil {
return emptyValue, err
}
timeVal = time.Unix(int64(t), 0)
case TypeNull:
if err := vr.ReadNull(); err != nil {
return emptyValue, err
}
case TypeUndefined:
if err := vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType)
}
if !tc.useLocalTimeZone && !dc.useLocalTimeZone {
timeVal = timeVal.UTC()
}
return reflect.ValueOf(timeVal), nil
}
// DecodeValue is the ValueDecoderFunc for time.Time.
func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tTime {
return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val}
}
elem, err := tc.decodeType(dc, vr, tTime)
if err != nil {
return err
}
val.Set(elem)
return nil
}
// EncodeValue is the ValueEncoderFunc for time.Time.
func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTime {
return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
}
tt := val.Interface().(time.Time)
dt := NewDateTimeFromTime(tt)
return vw.WriteDateTime(int64(dt))
}
+128
View File
@@ -0,0 +1,128 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"net/url"
"reflect"
"time"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// Type represents a BSON type.
type Type byte
// String returns the string representation of the BSON type's name.
func (bt Type) String() string {
return bsoncore.Type(bt).String()
}
// IsValid will return true if the Type is valid.
func (bt Type) IsValid() bool {
switch bt {
case TypeDouble, TypeString, TypeEmbeddedDocument, TypeArray, TypeBinary,
TypeUndefined, TypeObjectID, TypeBoolean, TypeDateTime, TypeNull, TypeRegex,
TypeDBPointer, TypeJavaScript, TypeSymbol, TypeCodeWithScope, TypeInt32,
TypeTimestamp, TypeInt64, TypeDecimal128, TypeMinKey, TypeMaxKey:
return true
default:
return false
}
}
// BSON element types as described in https://bsonspec.org/spec.html.
const (
TypeDouble Type = 0x01
TypeString Type = 0x02
TypeEmbeddedDocument Type = 0x03
TypeArray Type = 0x04
TypeBinary Type = 0x05
TypeUndefined Type = 0x06
TypeObjectID Type = 0x07
TypeBoolean Type = 0x08
TypeDateTime Type = 0x09
TypeNull Type = 0x0A
TypeRegex Type = 0x0B
TypeDBPointer Type = 0x0C
TypeJavaScript Type = 0x0D
TypeSymbol Type = 0x0E
TypeCodeWithScope Type = 0x0F
TypeInt32 Type = 0x10
TypeTimestamp Type = 0x11
TypeInt64 Type = 0x12
TypeDecimal128 Type = 0x13
TypeMaxKey Type = 0x7F
TypeMinKey Type = 0xFF
)
// BSON binary element subtypes as described in https://bsonspec.org/spec.html.
const (
TypeBinaryGeneric byte = 0x00
TypeBinaryFunction byte = 0x01
TypeBinaryBinaryOld byte = 0x02
TypeBinaryUUIDOld byte = 0x03
TypeBinaryUUID byte = 0x04
TypeBinaryMD5 byte = 0x05
TypeBinaryEncrypted byte = 0x06
TypeBinaryColumn byte = 0x07
TypeBinarySensitive byte = 0x08
TypeBinaryVector byte = 0x09
TypeBinaryUserDefined byte = 0x80
)
var (
tBool = reflect.TypeOf(false)
tFloat64 = reflect.TypeOf(float64(0))
tInt32 = reflect.TypeOf(int32(0))
tInt64 = reflect.TypeOf(int64(0))
tString = reflect.TypeOf("")
tTime = reflect.TypeOf(time.Time{})
)
var (
tEmpty = reflect.TypeOf((*any)(nil)).Elem()
tByteSlice = reflect.TypeOf([]byte(nil))
tByte = reflect.TypeOf(byte(0x00))
tURL = reflect.TypeOf(url.URL{})
tJSONNumber = reflect.TypeOf(json.Number(""))
)
var (
tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem()
tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem()
)
var (
tBinary = reflect.TypeOf(Binary{})
tUndefined = reflect.TypeOf(Undefined{})
tOID = reflect.TypeOf(ObjectID{})
tDateTime = reflect.TypeOf(DateTime(0))
tNull = reflect.TypeOf(Null{})
tRegex = reflect.TypeOf(Regex{})
tCodeWithScope = reflect.TypeOf(CodeWithScope{})
tDBPointer = reflect.TypeOf(DBPointer{})
tJavaScript = reflect.TypeOf(JavaScript(""))
tSymbol = reflect.TypeOf(Symbol(""))
tTimestamp = reflect.TypeOf(Timestamp{})
tDecimal = reflect.TypeOf(Decimal128{})
tVector = reflect.TypeOf(Vector{})
tMinKey = reflect.TypeOf(MinKey{})
tMaxKey = reflect.TypeOf(MaxKey{})
tD = reflect.TypeOf(D{})
tA = reflect.TypeOf(A{})
tE = reflect.TypeOf(E{})
)
var (
tCoreDocument = reflect.TypeOf(bsoncore.Document{})
tCoreArray = reflect.TypeOf(bsoncore.Array{})
)
+161
View File
@@ -0,0 +1,161 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"math"
"reflect"
)
// uintCodec is the Codec used for uint values.
type uintCodec struct {
// encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the
// minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value.
encodeToMinSize bool
}
// Assert that uintCodec satisfies the typeDecoder interface, which allows it to be used
// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection.
var _ typeDecoder = &uintCodec{}
// EncodeValue is the ValueEncoder for uint types.
func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Uint8, reflect.Uint16:
return vw.WriteInt32(int32(val.Uint()))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
u64 := val.Uint()
// If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32
useMinSize := ec.minSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64)
if u64 <= math.MaxInt32 && useMinSize {
return vw.WriteInt32(int32(u64))
}
if u64 > math.MaxInt64 {
return fmt.Errorf("%d overflows int64", u64)
}
return vw.WriteInt64(int64(u64))
}
return ValueEncoderError{
Name: "UintEncodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
var i64 int64
var err error
switch vrType := vr.Type(); vrType {
case TypeInt32:
i32, err := vr.ReadInt32()
if err != nil {
return emptyValue, err
}
i64 = int64(i32)
case TypeInt64:
i64, err = vr.ReadInt64()
if err != nil {
return emptyValue, err
}
case TypeDouble:
f64, err := vr.ReadDouble()
if err != nil {
return emptyValue, err
}
if !dc.truncate && math.Floor(f64) != f64 {
return emptyValue, errCannotTruncate
}
if f64 > float64(math.MaxInt64) {
return emptyValue, fmt.Errorf("%g overflows int64", f64)
}
i64 = int64(f64)
case TypeBoolean:
b, err := vr.ReadBoolean()
if err != nil {
return emptyValue, err
}
if b {
i64 = 1
}
case TypeNull:
if err = vr.ReadNull(); err != nil {
return emptyValue, err
}
case TypeUndefined:
if err = vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType)
}
switch t.Kind() {
case reflect.Uint8:
if i64 < 0 || i64 > math.MaxUint8 {
return emptyValue, fmt.Errorf("%d overflows uint8", i64)
}
return reflect.ValueOf(uint8(i64)), nil
case reflect.Uint16:
if i64 < 0 || i64 > math.MaxUint16 {
return emptyValue, fmt.Errorf("%d overflows uint16", i64)
}
return reflect.ValueOf(uint16(i64)), nil
case reflect.Uint32:
if i64 < 0 || i64 > math.MaxUint32 {
return emptyValue, fmt.Errorf("%d overflows uint32", i64)
}
return reflect.ValueOf(uint32(i64)), nil
case reflect.Uint64:
if i64 < 0 {
return emptyValue, fmt.Errorf("%d overflows uint64", i64)
}
return reflect.ValueOf(uint64(i64)), nil
case reflect.Uint:
if i64 < 0 {
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}
v := uint64(i64)
if v > math.MaxUint { // Can we fit this inside of an uint
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}
return reflect.ValueOf(uint(v)), nil
default:
return emptyValue, ValueDecoderError{
Name: "UintDecodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: reflect.Zero(t),
}
}
}
// DecodeValue is the ValueDecoder for uint types.
func (uic *uintCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() {
return ValueDecoderError{
Name: "UintDecodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
elem, err := uic.decodeType(dc, vr, val.Type())
if err != nil {
return err
}
val.SetUint(elem.Uint())
return nil
}
+90
View File
@@ -0,0 +1,90 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"fmt"
)
// Unmarshaler is the interface implemented by types that can unmarshal a BSON
// document representation of themselves. The input can be assumed to be a valid
// encoding of a BSON document. UnmarshalBSON must copy the JSON data if it
// wishes to retain the data after returning.
//
// Unmarshaler is only used to unmarshal full BSON documents. To create custom
// BSON unmarshaling behavior for individual values in a BSON document,
// implement the ValueUnmarshaler interface instead.
type Unmarshaler interface {
UnmarshalBSON([]byte) error
}
// ValueUnmarshaler is the interface implemented by types that can unmarshal a
// BSON value representation of themselves. The input can be assumed to be a
// valid encoding of a BSON value. UnmarshalBSONValue must copy the BSON value
// bytes if it wishes to retain the data after returning.
//
// ValueUnmarshaler is only used to unmarshal individual values in a BSON
// document. To create custom BSON unmarshaling behavior for an entire BSON
// document, implement the Unmarshaler interface instead.
type ValueUnmarshaler interface {
UnmarshalBSONValue(typ byte, data []byte) error
}
// Unmarshal parses the BSON-encoded data and stores the result in the value
// pointed to by val. If val is nil or not a pointer, Unmarshal returns an
// error.
//
// When unmarshaling BSON, if the BSON value is null and the Go value is a
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
func Unmarshal(data []byte, val any) error {
vr := getBufferedDocumentReader(data)
defer putBufferedDocumentReader(vr)
if l, err := vr.peekLength(); err != nil {
return err
} else if int(l) != len(data) {
return fmt.Errorf("invalid document length")
}
return unmarshalFromReader(DecodeContext{Registry: defaultRegistry}, vr, val)
}
// UnmarshalValue parses the BSON value of type t with bson.NewRegistry() and
// stores the result in the value pointed to by val. If val is nil or not a pointer,
// UnmarshalValue returns an error.
func UnmarshalValue(t Type, data []byte, val any) error {
vr := newBufferedValueReader(t, data)
return unmarshalFromReader(DecodeContext{Registry: defaultRegistry}, vr, val)
}
// UnmarshalExtJSON parses the extended JSON-encoded data and stores the result
// in the value pointed to by val. If val is nil or not a pointer, UnmarshalExtJSON
// returns an error.
//
// If canonicalOnly is true, UnmarshalExtJSON returns an error if the Extended
// JSON was not marshaled in canonical mode.
//
// For more information about Extended JSON, see
// https://www.mongodb.com/docs/manual/reference/mongodb-extended-json/
func UnmarshalExtJSON(data []byte, canonicalOnly bool, val any) error {
ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonicalOnly)
if err != nil {
return err
}
return unmarshalFromReader(DecodeContext{Registry: defaultRegistry}, ejvr, val)
}
func unmarshalFromReader(dc DecodeContext, vr ValueReader, val any) error {
dec := decPool.Get().(*Decoder)
defer decPool.Put(dec)
dec.Reset(vr)
dec.dc = dc
return dec.Decode(val)
}
+960
View File
@@ -0,0 +1,960 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"sync"
)
type byteSrc interface {
io.ByteReader
readExact(p []byte) (int, error)
// Peek returns the next n bytes without advancing the cursor. It must return
// exactly n bytes or n error if fewer are available.
peek(n int) ([]byte, error)
// discard advanced the cursor by n bytes, returning the actual number
// discarded or an error if fewer were available.
discard(n int) (int, error)
// readSlice reads until (and including) the first occurrence of delim,
// returning the entire slice [start...delimiter] and advancing the cursor.
// past it. Returns an error if delim is not found.
readSlice(delim byte) ([]byte, error)
// pos returns the number of bytes consumed so far.
pos() int64
// regexLength returns the total byte length of a BSON regex value (two
// C-strings including their terminating NULs) in buffered mode.
regexLength() (int32, error)
// streamable returns true if this source can be used in a streaming context.
streamable() bool
// reset resets the source to its initial state.
reset()
}
var _ ValueReader = &valueReader{}
// ErrEOA is the error returned when the end of a BSON array has been reached.
var ErrEOA = errors.New("end of array")
// ErrEOD is the error returned when the end of a BSON document has been reached.
var ErrEOD = errors.New("end of document")
type vrState struct {
mode mode
vType Type
end int64
}
var vrPool = sync.Pool{
New: func() any {
return &valueReader{
stack: make([]vrState, 1, 5),
}
},
}
// valueReader is for reading BSON values.
type valueReader struct {
src byteSrc
offset int64
stack []vrState
frame int64
}
func getBufferedDocumentReader(b []byte) *valueReader {
return newBufferedDocumentReader(b)
}
func putBufferedDocumentReader(vr *valueReader) {
if vr == nil {
return
}
vr.src.reset()
// Reset src and stack to avoid holding onto memory.
vr.src = nil
vr.frame = 0
vr.stack = vr.stack[:0]
vrPool.Put(vr)
}
// NewDocumentReader returns a ValueReader using b for the underlying BSON
// representation.
func NewDocumentReader(r io.Reader) ValueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mTopLevel,
}
return &valueReader{
src: &streamingByteSrc{br: bufio.NewReader(r), offset: 0},
stack: stack,
}
}
// newBufferedValueReader returns a ValueReader that starts in the Value mode
// instead of in top level document mode. This enables the creation of a
// ValueReader for a single BSON value.
func newBufferedValueReader(t Type, b []byte) ValueReader {
bVR := newBufferedDocumentReader(b)
bVR.stack[0].vType = t
bVR.stack[0].mode = mValue
return bVR
}
func newBufferedDocumentReader(b []byte) *valueReader {
vr := vrPool.Get().(*valueReader)
vr.src = &bufferedByteSrc{
buf: b,
offset: 0,
}
// Reset parse state.
vr.frame = 0
if cap(vr.stack) < 1 {
vr.stack = make([]vrState, 1, 5)
} else {
vr.stack = vr.stack[:1]
}
vr.stack[0] = vrState{
mode: mTopLevel,
end: int64(len(b)),
}
return vr
}
func (vr *valueReader) advanceFrame() {
if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack
length := len(vr.stack)
if length+1 >= cap(vr.stack) {
// double it
buf := make([]vrState, 2*cap(vr.stack)+1)
copy(buf, vr.stack)
vr.stack = buf
}
vr.stack = vr.stack[:length+1]
}
vr.frame++
// Clean the stack
vr.stack[vr.frame].mode = 0
vr.stack[vr.frame].vType = 0
vr.stack[vr.frame].end = 0
}
func (vr *valueReader) pop() error {
var cnt int
switch vr.stack[vr.frame].mode {
case mElement, mValue:
cnt = 1
case mDocument, mArray, mCodeWithScope:
cnt = 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
for i := 0; i < cnt && vr.frame > 0; i++ {
if vr.src.pos() < vr.stack[vr.frame].end {
_, err := vr.src.discard(int(vr.stack[vr.frame].end - vr.src.pos()))
if err != nil {
return err
}
}
vr.frame--
}
if vr.src.streamable() {
if vr.frame == 0 {
if vr.stack[0].end > vr.src.pos() {
vr.stack[0].end -= vr.src.pos()
} else {
vr.stack[0].end = 0
}
vr.src.reset()
}
}
return nil
}
func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vr.stack[vr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if vr.frame != 0 {
te.parent = vr.stack[vr.frame-1].mode
}
return te
}
func (vr *valueReader) typeError(t Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t)
}
func (vr *valueReader) invalidDocumentLengthError() error {
return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset)
}
func (vr *valueReader) ensureElementValue(t Type, destination mode, callerName string) error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
if vr.stack[vr.frame].vType != t {
return vr.typeError(t)
}
default:
return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue})
}
return nil
}
func (vr *valueReader) Type() Type {
return vr.stack[vr.frame].vType
}
// peekNextValueSize returns the length of the next value in the stream without
// offsetting the reader position.
func peekNextValueSize(vr *valueReader) (int32, error) {
var length int32
var err error
switch vr.stack[vr.frame].vType {
case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
length, err = vr.peekLength()
case TypeBinary:
length, err = vr.peekLength()
length += 4 + 1 // binary length + subtype byte
case TypeBoolean:
length = 1
case TypeDBPointer:
length, err = vr.peekLength()
length += 4 + 12 // string length + ObjectID length
case TypeDateTime, TypeDouble, TypeInt64, TypeTimestamp:
length = 8
case TypeDecimal128:
length = 16
case TypeInt32:
length = 4
case TypeJavaScript, TypeString, TypeSymbol:
length, err = vr.peekLength()
length += 4
case TypeMaxKey, TypeMinKey, TypeNull, TypeUndefined:
length = 0
case TypeObjectID:
length = 12
case TypeRegex:
length, err = vr.src.regexLength()
default:
return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType)
}
return length, err
}
// readBytes tries to grab the next n bytes zero-allocation using peek+discard.
// If peek fails (e.g. bufio buffer full), it falls back to io.ReadFull.
func readBytes(src byteSrc, n int) ([]byte, error) {
if src.streamable() {
data := make([]byte, n)
if _, err := src.readExact(data); err != nil {
return nil, err
}
return data, nil
}
// Zero-allocation path.
buf, err := src.peek(n)
if err != nil {
return nil, err
}
_, _ = src.discard(n) // Discard the bytes from the source.
return buf, nil
}
// readBytesValueReader returns a subslice [offset, offset+length) or EOF.
func (vr *valueReader) readBytes(n int32) ([]byte, error) {
if n < 0 {
return nil, fmt.Errorf("invalid length: %d", n)
}
return readBytes(vr.src, int(n))
}
//nolint:unparam
func (vr *valueReader) readValueBytes(dst []byte) (Type, []byte, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
length, err := vr.peekLength()
if err != nil {
return 0, nil, err
}
b, err := vr.readBytes(length)
return Type(0), append(dst, b...), err
case mElement, mValue:
t := vr.stack[vr.frame].vType
length, err := peekNextValueSize(vr)
if err != nil {
return t, dst, err
}
b, err := vr.readBytes(length)
if err := vr.pop(); err != nil {
return Type(0), nil, err
}
return t, append(dst, b...), err
default:
return Type(0), nil, vr.invalidTransitionErr(0, "readValueBytes", []mode{mElement, mValue})
}
}
func (vr *valueReader) Skip() error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
default:
return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
length, err := peekNextValueSize(vr)
if err != nil {
return err
}
_, err = vr.src.discard(int(length))
if err != nil {
return err
}
return vr.pop()
}
// ReadArray returns an ArrayReader for the next BSON array in the valueReader
// source, advancing the reader position to the end of the array.
func (vr *valueReader) ReadArray() (ArrayReader, error) {
if err := vr.ensureElementValue(TypeArray, mArray, "ReadArray"); err != nil {
return nil, err
}
// Push a new frame for the array.
vr.advanceFrame()
// Read the 4-byte length.
size, err := vr.readLength()
if err != nil {
return nil, err
}
// Compute the end position: current position + total size - length.
vr.stack[vr.frame].mode = mArray
vr.stack[vr.frame].end = vr.src.pos() + int64(size) - 4
return vr, nil
}
// ReadBinary reads a BSON binary value, returning the byte slice and the
// type of the binary data (0x02 for old binary, 0x00 for new binary, etc.),
// advancing the reader position to the end of the binary value.
func (vr *valueReader) ReadBinary() ([]byte, byte, error) {
if err := vr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
length, err := vr.readLength()
if err != nil {
return nil, 0, err
}
btype, err := vr.readByte()
if err != nil {
return nil, 0, err
}
// Check length in case it is an old binary without a length.
if btype == 0x02 && length > 4 {
length, err = vr.readLength()
if err != nil {
return nil, 0, err
}
}
b, err := vr.readBytes(length)
if err != nil {
return nil, 0, err
}
// copy so user doesnt share underlying buffer
cp := make([]byte, len(b))
copy(cp, b)
if err := vr.pop(); err != nil {
return nil, 0, err
}
return cp, btype, nil
}
// ReadBoolean reads a BSON boolean value, returning true or false, advancing
// the reader position to the end of the boolean value.
func (vr *valueReader) ReadBoolean() (bool, error) {
if err := vr.ensureElementValue(TypeBoolean, 0, "ReadBoolean"); err != nil {
return false, err
}
b, err := vr.readByte()
if err != nil {
return false, err
}
if b > 1 {
return false, fmt.Errorf("invalid byte for boolean, %b", b)
}
if err := vr.pop(); err != nil {
return false, err
}
return b == 1, nil
}
// ReadDocument reads a BSON embedded document, returning a DocumentReader,
// advancing the reader position to the end of the document.
func (vr *valueReader) ReadDocument() (DocumentReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
length, err := vr.readLength()
if err != nil {
return nil, err
}
if length <= 4 {
return nil, fmt.Errorf("invalid string length: %d", length)
}
vr.stack[vr.frame].end = int64(length) + vr.src.pos() - 4
return vr, nil
case mElement, mValue:
if vr.stack[vr.frame].vType != TypeEmbeddedDocument {
return nil, vr.typeError(TypeEmbeddedDocument)
}
default:
return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
vr.advanceFrame()
size, err := vr.readLength()
if err != nil {
return nil, err
}
vr.stack[vr.frame].mode = mDocument
vr.stack[vr.frame].end = int64(size) + vr.src.pos() - 4
return vr, nil
}
// ReadCodeWithScope reads a BSON CodeWithScope value, returning the code as a
// string, advancing the reader position to the end of the CodeWithScope value.
func (vr *valueReader) ReadCodeWithScope() (string, DocumentReader, error) {
if err := vr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
totalLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
strLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
if strLength <= 0 {
return "", nil, fmt.Errorf("invalid string length: %d", strLength)
}
buf, err := vr.readBytes(strLength)
if err != nil {
return "", nil, err
}
code := string(buf[:len(buf)-1])
vr.advanceFrame()
// Use readLength to ensure that we are not out of bounds.
size, err := vr.readLength()
if err != nil {
return "", nil, err
}
vr.stack[vr.frame].mode = mCodeWithScope
vr.stack[vr.frame].end = vr.src.pos() + int64(size) - 4
// The total length should equal:
// 4 (total length) + strLength + 4 (the length of str itself) + (document length)
componentsLength := int64(4+strLength+4) + int64(size)
if int64(totalLength) != componentsLength {
return "", nil, fmt.Errorf(
"length of CodeWithScope does not match lengths of components; total: %d; components: %d",
totalLength, componentsLength,
)
}
return code, vr, nil
}
// ReadDBPointer reads a BSON DBPointer value, returning the namespace, the
// object ID, and an error if any, advancing the reader position to the end of
// the DBPointer value.
func (vr *valueReader) ReadDBPointer() (string, ObjectID, error) {
if err := vr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
return "", ObjectID{}, err
}
ns, err := vr.readString()
if err != nil {
return "", ObjectID{}, err
}
oidBytes, err := vr.readBytes(12)
if err != nil {
return "", ObjectID{}, err
}
var oid ObjectID
copy(oid[:], oidBytes)
if err := vr.pop(); err != nil {
return "", ObjectID{}, err
}
return ns, oid, nil
}
// ReadDateTime reads a BSON DateTime value, advancing the reader position to
// the end of the DateTime value.
func (vr *valueReader) ReadDateTime() (int64, error) {
if err := vr.ensureElementValue(TypeDateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
i, err := vr.readi64()
if err != nil {
return 0, err
}
if err := vr.pop(); err != nil {
return 0, err
}
return i, nil
}
// ReadDecimal128 reads a BSON Decimal128 value, advancing the reader
// to the end of the Decimal128 value.
func (vr *valueReader) ReadDecimal128() (Decimal128, error) {
if err := vr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128"); err != nil {
return Decimal128{}, err
}
b, err := vr.readBytes(16)
if err != nil {
return Decimal128{}, err
}
l := binary.LittleEndian.Uint64(b[0:8])
h := binary.LittleEndian.Uint64(b[8:16])
if err := vr.pop(); err != nil {
return Decimal128{}, err
}
return NewDecimal128(h, l), nil
}
// ReadDouble reads a BSON double value, advancing the reader position to
// to the end of the double value.
func (vr *valueReader) ReadDouble() (float64, error) {
if err := vr.ensureElementValue(TypeDouble, 0, "ReadDouble"); err != nil {
return 0, err
}
u, err := vr.readu64()
if err != nil {
return 0, err
}
if err := vr.pop(); err != nil {
return 0, err
}
return math.Float64frombits(u), nil
}
// ReadInt32 reads a BSON int32 value, advancing the reader position to the end
// of the int32 value.
func (vr *valueReader) ReadInt32() (int32, error) {
if err := vr.ensureElementValue(TypeInt32, 0, "ReadInt32"); err != nil {
return 0, err
}
i, err := vr.readi32()
if err != nil {
return 0, err
}
if err := vr.pop(); err != nil {
return 0, err
}
return i, nil
}
// ReadInt64 reads a BSON int64 value, advancing the reader position to the end
// of the int64 value.
func (vr *valueReader) ReadInt64() (int64, error) {
if err := vr.ensureElementValue(TypeInt64, 0, "ReadInt64"); err != nil {
return 0, err
}
i, err := vr.readi64()
if err != nil {
return 0, err
}
if err := vr.pop(); err != nil {
return 0, err
}
return i, nil
}
// ReadJavascript reads a BSON JavaScript value, advancing the reader
// to the end of the JavaScript value.
func (vr *valueReader) ReadJavascript() (string, error) {
if err := vr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
s, err := vr.readString()
if err != nil {
return "", err
}
if err := vr.pop(); err != nil {
return "", err
}
return s, nil
}
// ReadMaxKey reads a BSON MaxKey value, advancing the reader position to the
// end of the MaxKey value.
func (vr *valueReader) ReadMaxKey() error {
if err := vr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
return vr.pop()
}
// ReadMinKey reads a BSON MinKey value, advancing the reader position to the
// end of the MinKey value.
func (vr *valueReader) ReadMinKey() error {
if err := vr.ensureElementValue(TypeMinKey, 0, "ReadMinKey"); err != nil {
return err
}
return vr.pop()
}
// REadNull reads a BSON Null value, advancing the reader position to the
// end of the Null value.
func (vr *valueReader) ReadNull() error {
if err := vr.ensureElementValue(TypeNull, 0, "ReadNull"); err != nil {
return err
}
return vr.pop()
}
// ReadObjectID reads a BSON ObjectID value, advancing the reader to the end of
// the ObjectID value.
func (vr *valueReader) ReadObjectID() (ObjectID, error) {
if err := vr.ensureElementValue(TypeObjectID, 0, "ReadObjectID"); err != nil {
return ObjectID{}, err
}
oidBytes, err := vr.readBytes(12)
if err != nil {
return ObjectID{}, err
}
var oid ObjectID
copy(oid[:], oidBytes)
if err := vr.pop(); err != nil {
return ObjectID{}, err
}
return oid, nil
}
// ReadRegex reads a BSON Regex value, advancing the reader position to the
// regex value.
func (vr *valueReader) ReadRegex() (string, string, error) {
if err := vr.ensureElementValue(TypeRegex, 0, "ReadRegex"); err != nil {
return "", "", err
}
pattern, err := vr.readCString()
if err != nil {
return "", "", err
}
options, err := vr.readCString()
if err != nil {
return "", "", err
}
if err := vr.pop(); err != nil {
return "", "", err
}
return pattern, options, nil
}
// ReadString reads a BSON String value, advancing the reader position to the
// end of the String value.
func (vr *valueReader) ReadString() (string, error) {
if err := vr.ensureElementValue(TypeString, 0, "ReadString"); err != nil {
return "", err
}
s, err := vr.readString()
if err != nil {
return "", err
}
if err := vr.pop(); err != nil {
return "", err
}
return s, nil
}
// ReadSymbol reads a BSON Symbol value, advancing the reader position to the
// end of the Symbol value.
func (vr *valueReader) ReadSymbol() (string, error) {
if err := vr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil {
return "", err
}
s, err := vr.readString()
if err != nil {
return "", err
}
if err := vr.pop(); err != nil {
return "", err
}
return s, nil
}
// ReadTimestamp reads a BSON Timestamp value, advancing the reader to the end
// of the Timestamp value.
func (vr *valueReader) ReadTimestamp() (uint32, uint32, error) {
if err := vr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
i, err := vr.readu32()
if err != nil {
return 0, 0, err
}
t, err := vr.readu32()
if err != nil {
return 0, 0, err
}
if err := vr.pop(); err != nil {
return 0, 0, err
}
return t, i, nil
}
// ReadUndefined reads a BSON Undefined value, advancing the reader position
// to the end of the Undefined value.
func (vr *valueReader) ReadUndefined() error {
if err := vr.ensureElementValue(TypeUndefined, 0, "ReadUndefined"); err != nil {
return err
}
return vr.pop()
}
// ReadElement reads the next element in the BSON document, advancing the
// reader position to the end of the element.
func (vr *valueReader) ReadElement() (string, ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
t, err := vr.readByte()
if err != nil {
return "", nil, err
}
if t == 0 {
if vr.src.pos() != vr.stack[vr.frame].end {
return "", nil, vr.invalidDocumentLengthError()
}
_ = vr.pop() // Ignore the error because the call here never reads from the underlying reader.
return "", nil, ErrEOD
}
name, err := vr.readCString()
if err != nil {
return "", nil, err
}
vr.advanceFrame()
vr.stack[vr.frame].mode = mElement
vr.stack[vr.frame].vType = Type(t)
return name, vr, nil
}
// ReadValue reads the next value in the BSON array, advancing the to the end of
// the value.
func (vr *valueReader) ReadValue() (ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mArray:
default:
return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := vr.readByte()
if err != nil {
return nil, err
}
if t == 0 {
if vr.src.pos() != vr.stack[vr.frame].end {
return nil, vr.invalidDocumentLengthError()
}
_ = vr.pop() // Ignore the error because the call here never reads from the underlying reader.
return nil, ErrEOA
}
_, err = vr.src.readSlice(0x00)
if err != nil {
return nil, err
}
vr.advanceFrame()
vr.stack[vr.frame].mode = mValue
vr.stack[vr.frame].vType = Type(t)
return vr, nil
}
func (vr *valueReader) readByte() (byte, error) {
b, err := vr.src.ReadByte()
if err != nil {
return 0x0, err
}
return b, nil
}
func (vr *valueReader) readCString() (string, error) {
data, err := vr.src.readSlice(0x00)
if err != nil {
return "", err
}
return string(data[:len(data)-1]), nil
}
func (vr *valueReader) readString() (string, error) {
length, err := vr.readLength()
if err != nil {
return "", err
}
if length <= 0 {
return "", fmt.Errorf("invalid string length: %d", length)
}
raw, err := readBytes(vr.src, int(length))
if err != nil {
return "", err
}
// Check that the last byte is the NUL terminator.
if raw[len(raw)-1] != 0x00 {
return "", fmt.Errorf("string does not end with null byte, but with %v", raw[len(raw)-1])
}
// Convert and strip the trailing NUL.
return string(raw[:len(raw)-1]), nil
}
func (vr *valueReader) peekLength() (int32, error) {
buf, err := vr.src.peek(4)
if err != nil {
return 0, err
}
return int32(binary.LittleEndian.Uint32(buf)), nil
}
func (vr *valueReader) readLength() (int32, error) {
return vr.readi32()
}
func (vr *valueReader) readi32() (int32, error) {
raw, err := readBytes(vr.src, 4)
if err != nil {
return 0, err
}
return int32(binary.LittleEndian.Uint32(raw)), nil
}
func (vr *valueReader) readu32() (uint32, error) {
raw, err := readBytes(vr.src, 4)
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint32(raw), nil
}
func (vr *valueReader) readi64() (int64, error) {
raw, err := readBytes(vr.src, 8)
if err != nil {
return 0, err
}
return int64(binary.LittleEndian.Uint64(raw)), nil
}
func (vr *valueReader) readu64() (uint64, error) {
raw, err := readBytes(vr.src, 8)
if err != nil {
return 0, err
}
return binary.LittleEndian.Uint64(raw), nil
}
+600
View File
@@ -0,0 +1,600 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
"math"
"strconv"
"strings"
"sync"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
var _ ValueWriter = &valueWriter{}
var vwPool = sync.Pool{
New: func() any {
return new(valueWriter)
},
}
func putValueWriter(vw *valueWriter) {
if vw != nil {
vw.w = nil // don't leak the writer
vwPool.Put(vw)
}
}
var documentWriterPool = sync.Pool{
New: func() any {
return newDocumentWriter(nil)
},
}
func getDocumentWriter(w io.Writer) *valueWriter {
vw := documentWriterPool.Get().(*valueWriter)
vw.reset(vw.buf)
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
func putDocumentWriter(vw *valueWriter) {
if vw != nil {
vw.w = nil // don't leak the writer
documentWriterPool.Put(vw)
}
}
// This is here so that during testing we can change it and not require
// allocating a 4GB slice.
var maxSize = math.MaxInt32
type errMaxDocumentSizeExceeded struct {
size int64
}
func (mdse errMaxDocumentSizeExceeded) Error() string {
return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size)
}
type vwMode int
const (
_ vwMode = iota
vwTopLevel
vwDocument
vwArray
vwValue
vwElement
vwCodeWithScope
)
func (vm vwMode) String() string {
var str string
switch vm {
case vwTopLevel:
str = "TopLevel"
case vwDocument:
str = "DocumentMode"
case vwArray:
str = "ArrayMode"
case vwValue:
str = "ValueMode"
case vwElement:
str = "ElementMode"
case vwCodeWithScope:
str = "CodeWithScopeMode"
default:
str = "UnknownMode"
}
return str
}
type vwState struct {
mode mode
key string
arrkey int
start int32
}
type valueWriter struct {
w io.Writer
buf []byte
stack []vwState
frame int64
}
func (vw *valueWriter) advanceFrame() {
vw.frame++
if vw.frame >= int64(len(vw.stack)) {
vw.stack = append(vw.stack, vwState{})
}
}
func (vw *valueWriter) push(m mode) {
vw.advanceFrame()
// Clean the stack
vw.stack[vw.frame] = vwState{mode: m}
switch m {
case mDocument, mArray, mCodeWithScope:
vw.reserveLength() // WARN: this is not needed
}
}
func (vw *valueWriter) reserveLength() {
vw.stack[vw.frame].start = int32(len(vw.buf))
vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
}
func (vw *valueWriter) pop() {
switch vw.stack[vw.frame].mode {
case mElement, mValue:
vw.frame--
case mDocument, mArray, mCodeWithScope:
vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
// NewDocumentWriter creates a ValueWriter that writes BSON to w.
//
// This ValueWriter will only write entire documents to the io.Writer and it
// will buffer the document as it is built.
func NewDocumentWriter(w io.Writer) ValueWriter {
return newDocumentWriter(w)
}
func newDocumentWriter(w io.Writer) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.w = w
vw.stack = stack
return vw
}
func newValueWriterFromSlice(buf []byte) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.stack = stack
vw.buf = buf
return vw
}
func (vw *valueWriter) reset(buf []byte) {
if vw.stack == nil {
vw.stack = make([]vwState, 1, 5)
}
vw.stack = vw.stack[:1]
vw.stack[0] = vwState{mode: mTopLevel}
vw.buf = buf
vw.frame = 0
vw.w = nil
}
func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vw.stack[vw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if vw.frame != 0 {
te.parent = vw.stack[vw.frame-1].mode
}
return te
}
func (vw *valueWriter) writeElementHeader(t Type, destination mode, callerName string, addmodes ...mode) error {
frame := &vw.stack[vw.frame]
switch frame.mode {
case mElement:
key := frame.key
if !isValidCString(key) {
return errors.New("BSON element key cannot contain null bytes")
}
vw.appendHeader(t, key)
case mValue:
vw.appendIntHeader(t, frame.arrkey)
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return vw.invalidTransitionError(destination, callerName, modes)
}
return nil
}
func (vw *valueWriter) writeValueBytes(t Type, b []byte) error {
if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
return err
}
vw.buf = append(vw.buf, b...)
vw.pop()
return nil
}
func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
if err := vw.writeElementHeader(TypeArray, mArray, "WriteArray"); err != nil {
return nil, err
}
vw.push(mArray)
return vw, nil
}
func (vw *valueWriter) WriteBinary(b []byte) error {
return vw.WriteBinaryWithSubtype(b, 0x00)
}
func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := vw.writeElementHeader(TypeBinary, mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteBoolean(b bool) error {
if err := vw.writeElementHeader(TypeBoolean, mode(0), "WriteBoolean"); err != nil {
return err
}
vw.buf = bsoncore.AppendBoolean(vw.buf, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := vw.writeElementHeader(TypeCodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
// CodeWithScope is a different than other types because we need an extra
// frame on the stack. In the EndDocument code, we write the document
// length, pop, write the code with scope length, and pop. To simplify the
// pop code, we push a spacer frame that we'll always jump over.
vw.push(mCodeWithScope)
vw.buf = bsoncore.AppendString(vw.buf, code)
vw.push(mSpacer)
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteDBPointer(ns string, oid ObjectID) error {
if err := vw.writeElementHeader(TypeDBPointer, mode(0), "WriteDBPointer"); err != nil {
return err
}
vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDateTime(dt int64) error {
if err := vw.writeElementHeader(TypeDateTime, mode(0), "WriteDateTime"); err != nil {
return err
}
vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDecimal128(d128 Decimal128) error {
if err := vw.writeElementHeader(TypeDecimal128, mode(0), "WriteDecimal128"); err != nil {
return err
}
h, l := d128.GetBytes()
vw.buf = bsoncore.AppendDecimal128(vw.buf, h, l)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDouble(f float64) error {
if err := vw.writeElementHeader(TypeDouble, mode(0), "WriteDouble"); err != nil {
return err
}
vw.buf = bsoncore.AppendDouble(vw.buf, f)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt32(i32 int32) error {
if err := vw.writeElementHeader(TypeInt32, mode(0), "WriteInt32"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt32(vw.buf, i32)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt64(i64 int64) error {
if err := vw.writeElementHeader(TypeInt64, mode(0), "WriteInt64"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt64(vw.buf, i64)
vw.pop()
return nil
}
func (vw *valueWriter) WriteJavascript(code string) error {
if err := vw.writeElementHeader(TypeJavaScript, mode(0), "WriteJavascript"); err != nil {
return err
}
vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
vw.pop()
return nil
}
func (vw *valueWriter) WriteMaxKey() error {
if err := vw.writeElementHeader(TypeMaxKey, mode(0), "WriteMaxKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteMinKey() error {
if err := vw.writeElementHeader(TypeMinKey, mode(0), "WriteMinKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteNull() error {
if err := vw.writeElementHeader(TypeNull, mode(0), "WriteNull"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteObjectID(oid ObjectID) error {
if err := vw.writeElementHeader(TypeObjectID, mode(0), "WriteObjectID"); err != nil {
return err
}
vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteRegex(pattern string, options string) error {
if !isValidCString(pattern) || !isValidCString(options) {
return errors.New("BSON regex values cannot contain null bytes")
}
if err := vw.writeElementHeader(TypeRegex, mode(0), "WriteRegex"); err != nil {
return err
}
vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
vw.pop()
return nil
}
func (vw *valueWriter) WriteString(s string) error {
if err := vw.writeElementHeader(TypeString, mode(0), "WriteString"); err != nil {
return err
}
vw.buf = bsoncore.AppendString(vw.buf, s)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
if vw.stack[vw.frame].mode == mTopLevel {
vw.reserveLength()
return vw, nil
}
if err := vw.writeElementHeader(TypeEmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteSymbol(symbol string) error {
if err := vw.writeElementHeader(TypeSymbol, mode(0), "WriteSymbol"); err != nil {
return err
}
vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
vw.pop()
return nil
}
func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := vw.writeElementHeader(TypeTimestamp, mode(0), "WriteTimestamp"); err != nil {
return err
}
vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
vw.pop()
return nil
}
func (vw *valueWriter) WriteUndefined() error {
if err := vw.writeElementHeader(TypeUndefined, mode(0), "WriteUndefined"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
}
vw.push(mElement)
vw.stack[vw.frame].key = key
return vw, nil
}
func (vw *valueWriter) WriteDocumentEnd() error {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
if vw.stack[vw.frame].mode == mTopLevel {
if err = vw.Flush(); err != nil {
return err
}
}
vw.pop()
if vw.stack[vw.frame].mode == mCodeWithScope {
// We ignore the error here because of the guarantee of writeLength.
// See the docs for writeLength for more info.
_ = vw.writeLength()
vw.pop()
}
return nil
}
func (vw *valueWriter) Flush() error {
if vw.w == nil {
return nil
}
if _, err := vw.w.Write(vw.buf); err != nil {
return err
}
// reset buffer
vw.buf = vw.buf[:0]
return nil
}
func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
if vw.stack[vw.frame].mode != mArray {
return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
}
arrkey := vw.stack[vw.frame].arrkey
vw.stack[vw.frame].arrkey++
vw.push(mValue)
vw.stack[vw.frame].arrkey = arrkey
return vw, nil
}
func (vw *valueWriter) WriteArrayEnd() error {
if vw.stack[vw.frame].mode != mArray {
return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
vw.pop()
return nil
}
// NOTE: We assume that if we call writeLength more than once the same function
// within the same function without altering the vw.buf that this method will
// not return an error. If this changes ensure that the following methods are
// updated:
//
// - WriteDocumentEnd
func (vw *valueWriter) writeLength() error {
length := len(vw.buf)
if length > maxSize {
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
}
frame := &vw.stack[vw.frame]
length -= int(frame.start)
start := frame.start
_ = vw.buf[start+3] // BCE
vw.buf[start+0] = byte(length)
vw.buf[start+1] = byte(length >> 8)
vw.buf[start+2] = byte(length >> 16)
vw.buf[start+3] = byte(length >> 24)
return nil
}
func isValidCString(cs string) bool {
// Disallow the zero byte in a cstring because the zero byte is used as the
// terminating character.
//
// It's safe to check bytes instead of runes because all multibyte UTF-8
// code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e.
// 0) will never be part of a multibyte UTF-8 code point. This logic is the
// same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be
// inlined.
//
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127
return strings.IndexByte(cs, 0) == -1
}
// appendHeader is the same as bsoncore.AppendHeader but does not check if the
// key is a valid C string since the caller has already checked for that.
//
// The caller of this function must check if key is a valid C string.
func (vw *valueWriter) appendHeader(t Type, key string) {
vw.buf = bsoncore.AppendType(vw.buf, bsoncore.Type(t))
vw.buf = append(vw.buf, key...)
vw.buf = append(vw.buf, 0x00)
}
func (vw *valueWriter) appendIntHeader(t Type, key int) {
vw.buf = bsoncore.AppendType(vw.buf, bsoncore.Type(t))
vw.buf = strconv.AppendInt(vw.buf, int64(key), 10)
vw.buf = append(vw.buf, 0x00)
}
+268
View File
@@ -0,0 +1,268 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/binary"
"errors"
"fmt"
"math"
)
// BSON binary vector types as described in https://bsonspec.org/spec.html.
const (
Int8Vector byte = 0x03
Float32Vector byte = 0x27
PackedBitVector byte = 0x10
)
// These are vector conversion errors.
var (
errInsufficientVectorData = errors.New("insufficient data")
errNonZeroVectorPadding = errors.New("padding must be 0")
errVectorPaddingTooLarge = errors.New("padding cannot be larger than 7")
)
type vectorTypeError struct {
Method string
Type byte
}
// Error implements the error interface.
func (vte vectorTypeError) Error() string {
t := "invalid"
switch vte.Type {
case Int8Vector:
t = "int8"
case Float32Vector:
t = "float32"
case PackedBitVector:
t = "packed bit"
}
return fmt.Sprintf("cannot call %s, on a type %s vector", vte.Method, t)
}
// Vector represents a densely packed array of numbers / bits.
type Vector struct {
dType byte
int8Data []int8
float32Data []float32
bitData []byte
bitPadding uint8
}
// Type returns the vector type.
func (v Vector) Type() byte {
return v.dType
}
// Int8 returns the int8 slice hold by the vector.
// It panics if v is not an int8 vector.
func (v Vector) Int8() []int8 {
d, ok := v.Int8OK()
if !ok {
panic(vectorTypeError{"bson.Vector.Int8", v.dType})
}
return d
}
// Int8OK is the same as Int8, but returns a boolean instead of panicking.
func (v Vector) Int8OK() ([]int8, bool) {
if v.dType != Int8Vector {
return nil, false
}
return v.int8Data, true
}
// Float32 returns the float32 slice hold by the vector.
// It panics if v is not a float32 vector.
func (v Vector) Float32() []float32 {
d, ok := v.Float32OK()
if !ok {
panic(vectorTypeError{"bson.Vector.Float32", v.dType})
}
return d
}
// Float32OK is the same as Float32, but returns a boolean instead of panicking.
func (v Vector) Float32OK() ([]float32, bool) {
if v.dType != Float32Vector {
return nil, false
}
return v.float32Data, true
}
// PackedBit returns the byte slice representing the binary quantized (packed bit) vector and the byte padding, which
// is the number of bits in the final byte that are to be ignored.
// It panics if v is not a packed bit vector.
func (v Vector) PackedBit() ([]byte, uint8) {
d, p, ok := v.PackedBitOK()
if !ok {
panic(vectorTypeError{"bson.Vector.PackedBit", v.dType})
}
return d, p
}
// PackedBitOK is the same as PackedBit, but returns a boolean instead of panicking.
func (v Vector) PackedBitOK() ([]byte, uint8, bool) {
if v.dType != PackedBitVector {
return nil, 0, false
}
return v.bitData, v.bitPadding, true
}
// Binary returns the BSON Binary representation of the Vector.
func (v Vector) Binary() Binary {
switch v.Type() {
case Int8Vector:
return binaryFromInt8Vector(v.Int8())
case Float32Vector:
return binaryFromFloat32Vector(v.Float32())
case PackedBitVector:
return binaryFromBitVector(v.PackedBit())
default:
panic(fmt.Sprintf("invalid Vector data type: %d", v.dType))
}
}
func binaryFromInt8Vector(v []int8) Binary {
data := make([]byte, len(v)+2)
data[0] = Int8Vector
data[1] = 0
for i, e := range v {
data[i+2] = byte(e)
}
return Binary{
Subtype: TypeBinaryVector,
Data: data,
}
}
func binaryFromFloat32Vector(v []float32) Binary {
data := make([]byte, 2, len(v)*4+2)
data[0] = Float32Vector
data[1] = 0
var a [4]byte
for _, e := range v {
binary.LittleEndian.PutUint32(a[:], math.Float32bits(e))
data = append(data, a[:]...)
}
return Binary{
Subtype: TypeBinaryVector,
Data: data,
}
}
func binaryFromBitVector(bits []byte, padding uint8) Binary {
data := make([]byte, len(bits)+2)
data[0] = PackedBitVector
data[1] = padding
copy(data[2:], bits)
return Binary{
Subtype: TypeBinaryVector,
Data: data,
}
}
// NewVector constructs a Vector from a slice of int8 or float32.
func NewVector[T int8 | float32](data []T) Vector {
var v Vector
switch a := any(data).(type) {
case []int8:
v.dType = Int8Vector
v.int8Data = make([]int8, len(data))
copy(v.int8Data, a)
case []float32:
v.dType = Float32Vector
v.float32Data = make([]float32, len(data))
copy(v.float32Data, a)
default:
panic(fmt.Errorf("unsupported type %T", data))
}
return v
}
// NewPackedBitVector constructs a Vector from a byte slice and a value of byte padding.
func NewPackedBitVector(bits []byte, padding uint8) (Vector, error) {
var v Vector
if padding > 7 {
return v, errVectorPaddingTooLarge
}
if padding > 0 && len(bits) == 0 {
return v, errNonZeroVectorPadding
}
v.dType = PackedBitVector
v.bitData = make([]byte, len(bits))
copy(v.bitData, bits)
v.bitPadding = padding
return v, nil
}
// NewVectorFromBinary unpacks a BSON Binary into a Vector.
func NewVectorFromBinary(b Binary) (Vector, error) {
var v Vector
if b.Subtype != TypeBinaryVector {
return v, errors.New("not a vector")
}
if len(b.Data) < 2 {
return v, errInsufficientVectorData
}
switch t := b.Data[0]; t {
case Int8Vector:
return newInt8Vector(b.Data[1:])
case Float32Vector:
return newFloat32Vector(b.Data[1:])
case PackedBitVector:
return newBitVector(b.Data[1:])
default:
return v, fmt.Errorf("invalid Vector data type: %d", t)
}
}
func newInt8Vector(b []byte) (Vector, error) {
var v Vector
if len(b) == 0 {
return v, errInsufficientVectorData
}
if padding := b[0]; padding > 0 {
return v, errNonZeroVectorPadding
}
s := make([]int8, 0, len(b)-1)
for i := 1; i < len(b); i++ {
s = append(s, int8(b[i]))
}
return NewVector(s), nil
}
func newFloat32Vector(b []byte) (Vector, error) {
var v Vector
if len(b) == 0 {
return v, errInsufficientVectorData
}
if padding := b[0]; padding > 0 {
return v, errNonZeroVectorPadding
}
l := (len(b) - 1) / 4
if l*4 != len(b)-1 {
return v, errInsufficientVectorData
}
s := make([]float32, 0, l)
for i := 1; i < len(b); i += 4 {
s = append(s, math.Float32frombits(binary.LittleEndian.Uint32(b[i:i+4])))
}
return NewVector(s), nil
}
func newBitVector(b []byte) (Vector, error) {
if len(b) == 0 {
return Vector{}, errInsufficientVectorData
}
return NewPackedBitVector(b[1:], b[0])
}
+61
View File
@@ -0,0 +1,61 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
// ArrayWriter is the interface used to create a BSON or BSON adjacent array.
// Callers must ensure they call WriteArrayEnd when they have finished creating
// the array.
type ArrayWriter interface {
WriteArrayElement() (ValueWriter, error)
WriteArrayEnd() error
}
// DocumentWriter is the interface used to create a BSON or BSON adjacent
// document. Callers must ensure they call WriteDocumentEnd when they have
// finished creating the document.
type DocumentWriter interface {
WriteDocumentElement(string) (ValueWriter, error)
WriteDocumentEnd() error
}
// ValueWriter is the interface used to write BSON values. Implementations of
// this interface handle creating BSON or BSON adjacent representations of the
// values.
type ValueWriter interface {
WriteArray() (ArrayWriter, error)
WriteBinary(b []byte) error
WriteBinaryWithSubtype(b []byte, btype byte) error
WriteBoolean(bool) error
WriteCodeWithScope(code string) (DocumentWriter, error)
WriteDBPointer(ns string, oid ObjectID) error
WriteDateTime(dt int64) error
WriteDecimal128(Decimal128) error
WriteDouble(float64) error
WriteInt32(int32) error
WriteInt64(int64) error
WriteJavascript(code string) error
WriteMaxKey() error
WriteMinKey() error
WriteNull() error
WriteObjectID(ObjectID) error
WriteRegex(pattern, options string) error
WriteString(string) error
WriteDocument() (DocumentWriter, error)
WriteSymbol(symbol string) error
WriteTimestamp(t, i uint32) error
WriteUndefined() error
}
// sliceWriter allows a pointer to a slice of bytes to be used as an io.Writer.
type sliceWriter []byte
// Write writes the bytes to the underlying slice.
func (sw *sliceWriter) Write(p []byte) (int, error) {
written := len(p)
*sw = append(*sw, p...)
return written, nil
}
+135
View File
@@ -0,0 +1,135 @@
// Copyright (C) MongoDB, Inc. 2025-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package binaryutil provides utility functions for working with binary data.
package binaryutil
import (
"bytes"
"encoding/binary"
)
// Append32 appends a uint32 or int32 value to dst in little-endian byte order.
// Byte shifting is done directly to prevent overflow security errors, in
// compliance with gosec G115.
//
// See: https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/encoding/binary/binary.go;l=92
func Append32[T ~uint32 | ~int32](dst []byte, v T) []byte {
return append(dst,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
)
}
// Append64 appends a uint64 or int64 value to dst in little-endian byte order.
// Byte shifting is done directly to prevent overflow security errors, in
// compliance with gosec G115.
//
// See: https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/encoding/binary/binary.go;l=119
func Append64[T ~uint64 | ~int64](dst []byte, v T) []byte {
return append(dst,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
byte(v>>32),
byte(v>>40),
byte(v>>48),
byte(v>>56),
)
}
// ReadU32 reads a uint32 from src in little-endian byte order. ReadU32 and
// ReadI32 are separate functions to avoid unsafe casting between unsigned and
// signed integers.
func ReadU32(src []byte) (uint32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return binary.LittleEndian.Uint32(src), src[4:], true
}
// ReadI32 reads an int32 from src in little-endian byte order.
// Byte shifting is done directly to prevent overflow security errors, in
// compliance with gosec G115. ReadU32 and ReadI32 are separate functions to
// avoid unsafe casting between unsigned and signed integers.
//
// See: https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/encoding/binary/binary.go;l=79
func ReadI32(src []byte) (int32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
_ = src[3] // bounds check hint to compiler
value := int32(src[0]) |
int32(src[1])<<8 |
int32(src[2])<<16 |
int32(src[3])<<24
return value, src[4:], true
}
// ReadU64 reads a uint64 from src in little-endian byte order. ReadU64 and
// ReadI64 are separate functions to avoid unsafe casting between unsigned and
// signed integers.
func ReadU64(src []byte) (uint64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
return binary.LittleEndian.Uint64(src), src[8:], true
}
// ReadI64 reads an int64 from src in little-endian byte order.
// Byte shifting is done directly to prevent overflow security errors, in
// compliance with gosec G115. ReadU64 and ReadI64 are separate functions to
// avoid unsafe casting between unsigned and signed integers.
//
// See: https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/encoding/binary/binary.go;l=101
func ReadI64(src []byte) (int64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
_ = src[7] // bounds check hint to compiler
value := int64(src[0]) |
int64(src[1])<<8 |
int64(src[2])<<16 |
int64(src[3])<<24 |
int64(src[4])<<32 |
int64(src[5])<<40 |
int64(src[6])<<48 |
int64(src[7])<<56
return value, src[8:], true
}
// ReadCStringBytes reads a null-terminated C string from src as a byte slice.
// This is the base implementation used by ReadCString to ensure a single source
// of truth for C string parsing logic.
func ReadCStringBytes(src []byte) ([]byte, []byte, bool) {
idx := bytes.IndexByte(src, 0x00)
if idx < 0 {
return nil, src, false
}
return src[:idx], src[idx+1:], true
}
// ReadCString reads a null-terminated C string from src as a string.
// It delegates to ReadCStringBytes to maintain a single source of truth for
// C string parsing logic.
func ReadCString(src []byte) (string, []byte, bool) {
cstr, rem, ok := ReadCStringBytes(src)
if !ok {
return "", src, false
}
return string(cstr), rem, true
}
+38
View File
@@ -0,0 +1,38 @@
// Copyright (C) MongoDB, Inc. 2026-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package binaryutil provides functions for reading binary primitives from
// byte slices. It is used internally for BSON parsing and wire protocol
// operations.
//
// The functions in this package are designed for use in BSON operations.
// Signed integer functions (ReadI32, ReadI64) use manual bit-shifting rather
// than encoding/binary to avoid unsafe signed/unsigned conversions and comply
// with gosec G115. Bounds-check elimination (BCE) hints help the compiler
// inline these functions.
//
// Benchmarking across different ARM64 architectures (Apple M-series)
// revealed non-deterministic performance discrepancies between using the
// "encoding/binary" standard library and manual bit-shifting ("straight-lining").
//
// Without Loss of Generality (WLOG), benchmarking observed that:
// - On Apple M1 Pro: Standard library (ReadU32) outperformed manual
// bit-shifting (ReadI32) by ~2x (~0.08ns vs ~0.16ns).
// - On Apple M4 Max: Manual bit-shifting (ReadI32) outperformed the
// standard library (ReadU32) by ~1.6x (~0.03ns vs ~0.05ns).
//
// Further testing showed that "straight-lining" the ReadU32 implementation
// to match ReadI32 normalized performance to ~0.03ns on the M4 Max, even
// though the generated assembly for both approaches is virtually equivalent.
//
// The generated assembly is nearly identical for both approaches. These
// sub-nanosecond variations likely stem from microarchitecture differences
// (instruction caching, branch prediction) rather than the code itself.
//
// Since network I/O dominates driver latency, these differences do not have a
// significant impact on driver performance. The implementation favors security
// compliance and readability over hardware-specific tuning.
package binaryutil
@@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncoreutil
// Truncate truncates a given string for a certain width
func Truncate(str string, width int) string {
if width <= 0 {
return ""
}
if len(str) <= width {
return str
}
// Truncate the byte slice of the string to the given width.
newStr := str[:width]
// Check if the last byte is at the beginning of a multi-byte character.
// If it is, then remove the last byte.
if newStr[len(newStr)-1]&0xC0 == 0xC0 {
return newStr[:len(newStr)-1]
}
// Check if the last byte is a multi-byte character
if newStr[len(newStr)-1]&0xC0 == 0x80 {
// If it is, step back until you we are at the start of a character
for i := len(newStr) - 1; i >= 0; i-- {
if newStr[i]&0xC0 == 0xC0 {
// Truncate at the end of the character before the character we stepped back to
return newStr[:i]
}
}
}
return newStr
}
+117
View File
@@ -0,0 +1,117 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package decimal128
import (
"strconv"
)
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
MaxDecimal128Exp = 6111
MinDecimal128Exp = -6176
)
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
div64 := uint64(div)
a := h >> 32
aq := a / div64
ar := a % div64
b := ar<<32 + h&(1<<32-1)
bq := b / div64
br := b % div64
c := br<<32 + l>>32
cq := c / div64
cr := c % div64
d := cr<<32 + l&(1<<32-1)
dq := d / div64
dr := d % div64
return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
}
// String returns a string representation of the decimal value.
func String(h, l uint64) string {
var posSign int // positive sign
var exp int // exponent
var high, low uint64 // significand high/low
if h>>63&1 == 0 {
posSign = 1
}
switch h >> 58 & (1<<5 - 1) {
case 0x1F:
return "NaN"
case 0x1E:
return "-Infinity"[posSign:]
}
low = l
if h>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
exp = int(h >> 47 & (1<<14 - 1))
// Spec says all of these values are out of range.
high, low = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
exp = int(h >> 49 & (1<<14 - 1))
high = h & (1<<49 - 1)
}
exp += MinDecimal128Exp
// Would be handled by the logic below, but that's trivial and common.
if high == 0 && low == 0 && exp == 0 {
return "-0"[posSign:]
}
var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
last := len(repr)
i := len(repr)
dot := len(repr) + exp
var rem uint32
Loop:
for d9 := 0; d9 < 5; d9++ {
high, low, rem = divmod(high, low, 1e9)
for d1 := 0; d1 < 9; d1++ {
// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
if i < len(repr) && (dot == i || low == 0 && high == 0 && rem > 0 && rem < 10 && (dot < i-6 || exp > 0)) {
exp += len(repr) - i
i--
repr[i] = '.'
last = i - 1
dot = len(repr) // Unmark.
}
c := '0' + byte(rem%10)
rem /= 10
i--
repr[i] = c
// Handle "0E+3", "1E+3", etc.
if low == 0 && high == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || exp > 0) {
last = i
break Loop
}
if c != '0' {
last = i
}
// Break early. Works without it, but why.
if dot > i && low == 0 && high == 0 && rem == 0 {
break Loop
}
}
}
repr[last-1] = '-'
last--
if exp > 0 {
return string(repr[last+posSign:]) + "E+" + strconv.Itoa(exp)
}
if exp < 0 {
return string(repr[last+posSign:]) + "E" + strconv.Itoa(exp)
}
return string(repr[last+posSign:])
}
+208
View File
@@ -0,0 +1,208 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"fmt"
"io"
"strconv"
"strings"
)
// NewArrayLengthError creates and returns an error for when the length of an array exceeds the
// bytes available.
func NewArrayLengthError(length, rem int) error {
return lengthError("array", length, rem)
}
// Array is a raw bytes representation of a BSON array.
type Array []byte
// NewArrayFromReader reads an array from r. This function will only validate the length is
// correct and that the array ends with a null byte.
func NewArrayFromReader(r io.Reader) (Array, error) {
return newBufferFromReader(r)
}
// Index searches for and retrieves the value at the given index. This method will panic if
// the array is invalid or if the index is out of bounds.
func (a Array) Index(index uint) Value {
value, err := a.IndexErr(index)
if err != nil {
panic(err)
}
return value
}
// IndexErr searches for and retrieves the value at the given index.
func (a Array) IndexErr(index uint) (Value, error) {
elem, err := indexErr(a, index)
if err != nil {
return Value{}, err
}
return elem.Value(), err
}
// DebugString outputs a human readable version of Array. It will attempt to stringify the
// valid components of the array even if the entire array is not valid.
func (a Array) DebugString() string {
if len(a) < 5 {
return "<malformed>"
}
var buf strings.Builder
buf.WriteString("Array")
length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString(")[")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
buf.WriteString(elem.Value().DebugString())
if length != 1 {
buf.WriteByte(',')
}
}
buf.WriteByte(']')
return buf.String()
}
// String outputs an ExtendedJSON version of Array. If the Array is not valid, this method
// returns an empty string.
func (a Array) String() string {
str, _ := a.StringN(-1)
return str
}
// StringN stringifies an array. If N is non-negative, it will truncate the string to N bytes.
// Otherwise, it will return the full string representation. The second return value indicates
// whether the string was truncated or not.
func (a Array) StringN(n int) (string, bool) {
length, rem, ok := ReadLength(a)
if !ok || length < 5 {
return "", false
}
length -= 4 // length bytes
length-- // final null byte
if n == 0 {
return "", true
}
var buf strings.Builder
buf.WriteByte('[')
var truncated bool
var elem Element
var str string
first := true
for length > 0 && !truncated {
needStrLen := -1
// Set needStrLen if n is positive, meaning we want to limit the string length.
if n > 0 {
// Stop stringifying if we reach the limit, that also ensures needStrLen is
// greater than 0 if we need to limit the length.
if buf.Len() >= n {
truncated = true
break
}
needStrLen = n - buf.Len()
}
// Append a comma if this is not the first element.
if !first {
buf.WriteByte(',')
// If we are truncating, we need to account for the comma in the length.
if needStrLen > 0 {
needStrLen--
if needStrLen == 0 {
truncated = true
break
}
}
}
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
// Exit on malformed element.
if !ok || length < 0 {
return "", false
}
// Delegate to StringN() on the element.
str, truncated = elem.Value().StringN(needStrLen)
buf.WriteString(str)
first = false
}
if n <= 0 || (buf.Len() < n && !truncated) {
buf.WriteByte(']')
} else {
truncated = true
}
return buf.String(), truncated
}
// Values returns this array as a slice of values. The returned slice will contain valid values.
// If the array is not valid, the values up to the invalid point will be returned along with an
// error.
func (a Array) Values() ([]Value, error) {
return values(a)
}
// Validate validates the array and ensures the elements contained within are valid.
func (a Array) Validate() error {
length, rem, ok := ReadLength(a)
if !ok {
return NewInsufficientBytesError(a, rem)
}
if int(length) > len(a) {
return NewArrayLengthError(int(length), len(a))
}
if a[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
var keyNum int64
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(a, rem)
}
// validate element
err := elem.Validate()
if err != nil {
return err
}
// validate keys increase numerically
if fmt.Sprint(keyNum) != elem.Key() {
return fmt.Errorf("array key %q is out of order or invalid", elem.Key())
}
keyNum++
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}
@@ -0,0 +1,198 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"strconv"
)
// ArrayBuilder builds a bson array
type ArrayBuilder struct {
arr []byte
indexes []int32
keys []int
}
// NewArrayBuilder creates a new ArrayBuilder
func NewArrayBuilder() *ArrayBuilder {
return (&ArrayBuilder{}).startArray()
}
// startArray reserves the array's length and sets the index to where the length begins
func (a *ArrayBuilder) startArray() *ArrayBuilder {
var index int32
index, a.arr = AppendArrayStart(a.arr)
a.indexes = append(a.indexes, index)
a.keys = append(a.keys, 0)
return a
}
// Build updates the length of the array and index to the beginning of the documents length
// bytes, then returns the array (bson bytes)
func (a *ArrayBuilder) Build() Array {
lastIndex := len(a.indexes) - 1
lastKey := len(a.keys) - 1
a.arr, _ = AppendArrayEnd(a.arr, a.indexes[lastIndex])
a.indexes = a.indexes[:lastIndex]
a.keys = a.keys[:lastKey]
return a.arr
}
// incrementKey() increments the value keys and returns the key to be used to a.appendArray* functions
func (a *ArrayBuilder) incrementKey() string {
idx := len(a.keys) - 1
key := strconv.Itoa(a.keys[idx])
a.keys[idx]++
return key
}
// AppendInt32 will append i32 to ArrayBuilder.arr
func (a *ArrayBuilder) AppendInt32(i32 int32) *ArrayBuilder {
a.arr = AppendInt32Element(a.arr, a.incrementKey(), i32)
return a
}
// AppendDocument will append doc to ArrayBuilder.arr
func (a *ArrayBuilder) AppendDocument(doc []byte) *ArrayBuilder {
a.arr = AppendDocumentElement(a.arr, a.incrementKey(), doc)
return a
}
// AppendArray will append arr to ArrayBuilder.arr
func (a *ArrayBuilder) AppendArray(arr []byte) *ArrayBuilder {
a.arr = AppendArrayElement(a.arr, a.incrementKey(), arr)
return a
}
// AppendDouble will append f to ArrayBuilder.doc
func (a *ArrayBuilder) AppendDouble(f float64) *ArrayBuilder {
a.arr = AppendDoubleElement(a.arr, a.incrementKey(), f)
return a
}
// AppendString will append str to ArrayBuilder.doc
func (a *ArrayBuilder) AppendString(str string) *ArrayBuilder {
a.arr = AppendStringElement(a.arr, a.incrementKey(), str)
return a
}
// AppendObjectID will append oid to ArrayBuilder.doc
func (a *ArrayBuilder) AppendObjectID(oid objectID) *ArrayBuilder {
a.arr = AppendObjectIDElement(a.arr, a.incrementKey(), oid)
return a
}
// AppendBinary will append a BSON binary element using subtype, and
// b to a.arr
func (a *ArrayBuilder) AppendBinary(subtype byte, b []byte) *ArrayBuilder {
a.arr = AppendBinaryElement(a.arr, a.incrementKey(), subtype, b)
return a
}
// AppendUndefined will append a BSON undefined element using key to a.arr
func (a *ArrayBuilder) AppendUndefined() *ArrayBuilder {
a.arr = AppendUndefinedElement(a.arr, a.incrementKey())
return a
}
// AppendBoolean will append a boolean element using b to a.arr
func (a *ArrayBuilder) AppendBoolean(b bool) *ArrayBuilder {
a.arr = AppendBooleanElement(a.arr, a.incrementKey(), b)
return a
}
// AppendDateTime will append datetime element dt to a.arr
func (a *ArrayBuilder) AppendDateTime(dt int64) *ArrayBuilder {
a.arr = AppendDateTimeElement(a.arr, a.incrementKey(), dt)
return a
}
// AppendNull will append a null element to a.arr
func (a *ArrayBuilder) AppendNull() *ArrayBuilder {
a.arr = AppendNullElement(a.arr, a.incrementKey())
return a
}
// AppendRegex will append pattern and options to a.arr
func (a *ArrayBuilder) AppendRegex(pattern, options string) *ArrayBuilder {
a.arr = AppendRegexElement(a.arr, a.incrementKey(), pattern, options)
return a
}
// AppendDBPointer will append ns and oid to a.arr
func (a *ArrayBuilder) AppendDBPointer(ns string, oid objectID) *ArrayBuilder {
a.arr = AppendDBPointerElement(a.arr, a.incrementKey(), ns, oid)
return a
}
// AppendJavaScript will append js to a.arr
func (a *ArrayBuilder) AppendJavaScript(js string) *ArrayBuilder {
a.arr = AppendJavaScriptElement(a.arr, a.incrementKey(), js)
return a
}
// AppendSymbol will append symbol to a.arr
func (a *ArrayBuilder) AppendSymbol(symbol string) *ArrayBuilder {
a.arr = AppendSymbolElement(a.arr, a.incrementKey(), symbol)
return a
}
// AppendCodeWithScope will append code and scope to a.arr
func (a *ArrayBuilder) AppendCodeWithScope(code string, scope Document) *ArrayBuilder {
a.arr = AppendCodeWithScopeElement(a.arr, a.incrementKey(), code, scope)
return a
}
// AppendTimestamp will append t and i to a.arr
func (a *ArrayBuilder) AppendTimestamp(t, i uint32) *ArrayBuilder {
a.arr = AppendTimestampElement(a.arr, a.incrementKey(), t, i)
return a
}
// AppendInt64 will append i64 to a.arr
func (a *ArrayBuilder) AppendInt64(i64 int64) *ArrayBuilder {
a.arr = AppendInt64Element(a.arr, a.incrementKey(), i64)
return a
}
// AppendDecimal128 will append d128 to a.arr
func (a *ArrayBuilder) AppendDecimal128(high, low uint64) *ArrayBuilder {
a.arr = AppendDecimal128Element(a.arr, a.incrementKey(), high, low)
return a
}
// AppendMaxKey will append a max key element to a.arr
func (a *ArrayBuilder) AppendMaxKey() *ArrayBuilder {
a.arr = AppendMaxKeyElement(a.arr, a.incrementKey())
return a
}
// AppendMinKey will append a min key element to a.arr
func (a *ArrayBuilder) AppendMinKey() *ArrayBuilder {
a.arr = AppendMinKeyElement(a.arr, a.incrementKey())
return a
}
// AppendValue appends a BSON value to the array.
func (a *ArrayBuilder) AppendValue(val Value) *ArrayBuilder {
a.arr = AppendValueElement(a.arr, a.incrementKey(), val)
return a
}
// StartArray starts building an inline Array. After this document is completed,
// the user must call a.FinishArray
func (a *ArrayBuilder) StartArray() *ArrayBuilder {
a.arr = AppendHeader(a.arr, TypeArray, a.incrementKey())
a.startArray()
return a
}
// FinishArray builds the most recent array created
func (a *ArrayBuilder) FinishArray() *ArrayBuilder {
a.arr = a.Build()
return a
}
@@ -0,0 +1,184 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
// DocumentBuilder builds a bson document
type DocumentBuilder struct {
doc []byte
indexes []int32
}
// startDocument reserves the document's length and set the index to where the length begins
func (db *DocumentBuilder) startDocument() *DocumentBuilder {
var index int32
index, db.doc = AppendDocumentStart(db.doc)
db.indexes = append(db.indexes, index)
return db
}
// NewDocumentBuilder creates a new DocumentBuilder
func NewDocumentBuilder() *DocumentBuilder {
return (&DocumentBuilder{}).startDocument()
}
// Build updates the length of the document and index to the beginning of the documents length
// bytes, then returns the document (bson bytes)
func (db *DocumentBuilder) Build() Document {
last := len(db.indexes) - 1
db.doc, _ = AppendDocumentEnd(db.doc, db.indexes[last])
db.indexes = db.indexes[:last]
return db.doc
}
// AppendInt32 will append an int32 element using key and i32 to DocumentBuilder.doc
func (db *DocumentBuilder) AppendInt32(key string, i32 int32) *DocumentBuilder {
db.doc = AppendInt32Element(db.doc, key, i32)
return db
}
// AppendDocument will append a bson embedded document element using key
// and doc to DocumentBuilder.doc
func (db *DocumentBuilder) AppendDocument(key string, doc []byte) *DocumentBuilder {
db.doc = AppendDocumentElement(db.doc, key, doc)
return db
}
// AppendArray will append a bson array using key and arr to DocumentBuilder.doc
func (db *DocumentBuilder) AppendArray(key string, arr []byte) *DocumentBuilder {
db.doc = AppendHeader(db.doc, TypeArray, key)
db.doc = AppendArray(db.doc, arr)
return db
}
// AppendDouble will append a double element using key and f to DocumentBuilder.doc
func (db *DocumentBuilder) AppendDouble(key string, f float64) *DocumentBuilder {
db.doc = AppendDoubleElement(db.doc, key, f)
return db
}
// AppendString will append str to DocumentBuilder.doc with the given key
func (db *DocumentBuilder) AppendString(key string, str string) *DocumentBuilder {
db.doc = AppendStringElement(db.doc, key, str)
return db
}
// AppendObjectID will append oid to DocumentBuilder.doc with the given key
func (db *DocumentBuilder) AppendObjectID(key string, oid objectID) *DocumentBuilder {
db.doc = AppendObjectIDElement(db.doc, key, oid)
return db
}
// AppendBinary will append a BSON binary element using key, subtype, and
// b to db.doc
func (db *DocumentBuilder) AppendBinary(key string, subtype byte, b []byte) *DocumentBuilder {
db.doc = AppendBinaryElement(db.doc, key, subtype, b)
return db
}
// AppendUndefined will append a BSON undefined element using key to db.doc
func (db *DocumentBuilder) AppendUndefined(key string) *DocumentBuilder {
db.doc = AppendUndefinedElement(db.doc, key)
return db
}
// AppendBoolean will append a boolean element using key and b to db.doc
func (db *DocumentBuilder) AppendBoolean(key string, b bool) *DocumentBuilder {
db.doc = AppendBooleanElement(db.doc, key, b)
return db
}
// AppendDateTime will append a datetime element using key and dt to db.doc
func (db *DocumentBuilder) AppendDateTime(key string, dt int64) *DocumentBuilder {
db.doc = AppendDateTimeElement(db.doc, key, dt)
return db
}
// AppendNull will append a null element using key to db.doc
func (db *DocumentBuilder) AppendNull(key string) *DocumentBuilder {
db.doc = AppendNullElement(db.doc, key)
return db
}
// AppendRegex will append pattern and options using key to db.doc
func (db *DocumentBuilder) AppendRegex(key, pattern, options string) *DocumentBuilder {
db.doc = AppendRegexElement(db.doc, key, pattern, options)
return db
}
// AppendDBPointer will append ns and oid to using key to db.doc
func (db *DocumentBuilder) AppendDBPointer(key string, ns string, oid objectID) *DocumentBuilder {
db.doc = AppendDBPointerElement(db.doc, key, ns, oid)
return db
}
// AppendJavaScript will append js using the provided key to db.doc
func (db *DocumentBuilder) AppendJavaScript(key, js string) *DocumentBuilder {
db.doc = AppendJavaScriptElement(db.doc, key, js)
return db
}
// AppendSymbol will append a BSON symbol element using key and symbol db.doc
func (db *DocumentBuilder) AppendSymbol(key, symbol string) *DocumentBuilder {
db.doc = AppendSymbolElement(db.doc, key, symbol)
return db
}
// AppendCodeWithScope will append code and scope using key to db.doc
func (db *DocumentBuilder) AppendCodeWithScope(key string, code string, scope Document) *DocumentBuilder {
db.doc = AppendCodeWithScopeElement(db.doc, key, code, scope)
return db
}
// AppendTimestamp will append t and i to db.doc using provided key
func (db *DocumentBuilder) AppendTimestamp(key string, t, i uint32) *DocumentBuilder {
db.doc = AppendTimestampElement(db.doc, key, t, i)
return db
}
// AppendInt64 will append i64 to dst using key to db.doc
func (db *DocumentBuilder) AppendInt64(key string, i64 int64) *DocumentBuilder {
db.doc = AppendInt64Element(db.doc, key, i64)
return db
}
// AppendDecimal128 will append d128 to db.doc using provided key
func (db *DocumentBuilder) AppendDecimal128(key string, high, low uint64) *DocumentBuilder {
db.doc = AppendDecimal128Element(db.doc, key, high, low)
return db
}
// AppendMaxKey will append a max key element using key to db.doc
func (db *DocumentBuilder) AppendMaxKey(key string) *DocumentBuilder {
db.doc = AppendMaxKeyElement(db.doc, key)
return db
}
// AppendMinKey will append a min key element using key to db.doc
func (db *DocumentBuilder) AppendMinKey(key string) *DocumentBuilder {
db.doc = AppendMinKeyElement(db.doc, key)
return db
}
// AppendValue will append a BSON element with the provided key and value to the document.
func (db *DocumentBuilder) AppendValue(key string, val Value) *DocumentBuilder {
db.doc = AppendValueElement(db.doc, key, val)
return db
}
// StartDocument starts building an inline document element with the provided key
// After this document is completed, the user must call finishDocument
func (db *DocumentBuilder) StartDocument(key string) *DocumentBuilder {
db.doc = AppendHeader(db.doc, TypeEmbeddedDocument, key)
db = db.startDocument()
return db
}
// FinishDocument builds the most recent document created
func (db *DocumentBuilder) FinishDocument() *DocumentBuilder {
db.doc = db.Build()
return db
}
+773
View File
@@ -0,0 +1,773 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"
"time"
"go.mongodb.org/mongo-driver/v2/internal/binaryutil"
)
const (
// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
EmptyDocumentLength = 5
// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings.
nullTerminator = string(byte(0))
invalidKeyPanicMsg = "BSON element keys cannot contain null bytes"
invalidRegexPanicMsg = "BSON regex values cannot contain null bytes"
)
type objectID = [12]byte
// AppendType will append t to dst and return the extended buffer.
func AppendType(dst []byte, t Type) []byte { return append(dst, byte(t)) }
// AppendKey will append key to dst and return the extended buffer.
func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTerminator...) }
// AppendHeader will append Type t and key to dst and return the extended
// buffer.
func AppendHeader(dst []byte, t Type, key string) []byte {
if !isValidCString(key) {
panic(invalidKeyPanicMsg)
}
dst = AppendType(dst, t)
dst = append(dst, key...)
return append(dst, 0x00)
// return append(AppendType(dst, t), key+string(0x00)...)
}
// TODO(skriptble): All of the Read* functions should return src resliced to start just after what was read.
// ReadType will return the first byte of the provided []byte as a type. If
// there is no available byte, false is returned.
func ReadType(src []byte) (Type, []byte, bool) {
if len(src) < 1 {
return 0, src, false
}
return Type(src[0]), src[1:], true
}
// ReadKey will read a key from src. The 0x00 byte will not be present
// in the returned string. If there are not enough bytes available, false is
// returned.
func ReadKey(src []byte) (string, []byte, bool) { return binaryutil.ReadCString(src) }
// ReadKeyBytes will read a key from src as bytes. The 0x00 byte will
// not be present in the returned string. If there are not enough bytes
// available, false is returned.
func ReadKeyBytes(src []byte) ([]byte, []byte, bool) { return binaryutil.ReadCStringBytes(src) }
// ReadHeader will read a type byte and a key from src. If both of these
// values cannot be read, false is returned.
func ReadHeader(src []byte) (t Type, key string, rem []byte, ok bool) {
t, rem, ok = ReadType(src)
if !ok {
return 0, "", src, false
}
key, rem, ok = ReadKey(rem)
if !ok {
return 0, "", src, false
}
return t, key, rem, true
}
// ReadHeaderBytes will read a type and a key from src and the remainder of the bytes
// are returned as rem. If either the type or key cannot be red, ok will be false.
func ReadHeaderBytes(src []byte) (header []byte, rem []byte, ok bool) {
if len(src) < 1 {
return nil, src, false
}
idx := bytes.IndexByte(src[1:], 0x00)
if idx == -1 {
return nil, src, false
}
return src[:idx], src[idx+1:], true
}
// ReadElement reads the next full element from src. It returns the element, the remaining bytes in
// the slice, and a boolean indicating if the read was successful.
func ReadElement(src []byte) (Element, []byte, bool) {
if len(src) < 1 {
return nil, src, false
}
t := Type(src[0])
idx := 1
for idx < len(src) && src[idx] != 0x00 {
idx++
}
if idx >= len(src) {
return nil, src, false
}
idx++ // Move past the null byte
length, ok := valueLength(src[idx:], t)
if !ok {
return nil, src, false
}
elemLength := idx + int(length)
if elemLength > len(src) {
return nil, src, false
}
return src[:elemLength], src[elemLength:], true
}
// AppendValueElement appends value to dst as an element using key as the element's key.
func AppendValueElement(dst []byte, key string, value Value) []byte {
dst = AppendHeader(dst, value.Type, key)
dst = append(dst, value.Data...)
return dst
}
// ReadValue reads the next value as the provided types and returns a Value, the remaining bytes,
// and a boolean indicating if the read was successful.
func ReadValue(src []byte, t Type) (Value, []byte, bool) {
data, rem, ok := readValue(src, t)
if !ok {
return Value{}, src, false
}
return Value{Type: t, Data: data}, rem, true
}
// AppendDouble will append f to dst and return the extended buffer.
func AppendDouble(dst []byte, f float64) []byte {
return binaryutil.Append64(dst, math.Float64bits(f))
}
// AppendDoubleElement will append a BSON double element using key and f to dst
// and return the extended buffer.
func AppendDoubleElement(dst []byte, key string, f float64) []byte {
return AppendDouble(AppendHeader(dst, TypeDouble, key), f)
}
// ReadDouble will read a float64 from src. If there are not enough bytes it
// will return false.
func ReadDouble(src []byte) (float64, []byte, bool) {
bits, src, ok := binaryutil.ReadU64(src)
if !ok {
return 0, src, false
}
return math.Float64frombits(bits), src, true
}
// AppendString will append s to dst and return the extended buffer.
func AppendString(dst []byte, s string) []byte {
return appendstring(dst, s)
}
// AppendStringElement will append a BSON string element using key and val to dst
// and return the extended buffer.
func AppendStringElement(dst []byte, key, val string) []byte {
return AppendString(AppendHeader(dst, TypeString, key), val)
}
// ReadString will read a string from src. If there are not enough bytes it
// will return false.
func ReadString(src []byte) (string, []byte, bool) {
return readstring(src)
}
// AppendDocumentStart reserves a document's length and returns the index where the length begins.
// This index can later be used to write the length of the document.
func AppendDocumentStart(dst []byte) (index int32, b []byte) {
// TODO(skriptble): We really need AppendDocumentStart and AppendDocumentEnd. AppendDocumentStart would handle calling
// TODO ReserveLength and providing the index of the start of the document. AppendDocumentEnd would handle taking that
// TODO start index, adding the null byte, calculating the length, and filling in the length at the start of the
// TODO document.
return ReserveLength(dst)
}
// AppendDocumentStartInline functions the same as AppendDocumentStart but takes a pointer to the
// index int32 which allows this function to be used inline.
func AppendDocumentStartInline(dst []byte, index *int32) []byte {
idx, doc := AppendDocumentStart(dst)
*index = idx
return doc
}
// AppendDocumentElementStart writes a document element header and then reserves the length bytes.
func AppendDocumentElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendDocumentStart(AppendHeader(dst, TypeEmbeddedDocument, key))
}
// AppendDocumentEnd writes the null byte for a document and updates the length of the document.
// The index should be the beginning of the document's length bytes.
func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) {
if int(index) > len(dst)-4 {
return dst, fmt.Errorf("not enough bytes available after index to write length")
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, index, int32(len(dst[index:])))
return dst, nil
}
// AppendDocument will append doc to dst and return the extended buffer.
func AppendDocument(dst []byte, doc []byte) []byte { return append(dst, doc...) }
// AppendDocumentElement will append a BSON embedded document element using key
// and doc to dst and return the extended buffer.
func AppendDocumentElement(dst []byte, key string, doc []byte) []byte {
return AppendDocument(AppendHeader(dst, TypeEmbeddedDocument, key), doc)
}
// BuildDocument will create a document with the given slice of elements and will append
// it to dst and return the extended buffer.
func BuildDocument(dst []byte, elems ...[]byte) []byte {
idx, dst := ReserveLength(dst)
for _, elem := range elems {
dst = append(dst, elem...)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildDocumentValue creates an Embedded Document value from the given elements.
func BuildDocumentValue(elems ...[]byte) Value {
return Value{Type: TypeEmbeddedDocument, Data: BuildDocument(nil, elems...)}
}
// BuildDocumentElement will append a BSON embedded document element using key and the provided
// elements and return the extended buffer.
func BuildDocumentElement(dst []byte, key string, elems ...[]byte) []byte {
return BuildDocument(AppendHeader(dst, TypeEmbeddedDocument, key), elems...)
}
// BuildDocumentFromElements is an alaias for the BuildDocument function.
var BuildDocumentFromElements = BuildDocument
// ReadDocument will read a document from src. If there are not enough bytes it
// will return false.
func ReadDocument(src []byte) (doc Document, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendArrayStart appends the length bytes to an array and then returns the index of the start
// of those length bytes.
func AppendArrayStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
// AppendArrayElementStart appends an array element header and then the length bytes for an array,
// returning the index where the length starts.
func AppendArrayElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendArrayStart(AppendHeader(dst, TypeArray, key))
}
// AppendArrayEnd appends the null byte to an array and calculates the length, inserting that
// calculated length starting at index.
func AppendArrayEnd(dst []byte, index int32) ([]byte, error) { return AppendDocumentEnd(dst, index) }
// AppendArray will append arr to dst and return the extended buffer.
func AppendArray(dst []byte, arr []byte) []byte { return append(dst, arr...) }
// AppendArrayElement will append a BSON array element using key and arr to dst
// and return the extended buffer.
func AppendArrayElement(dst []byte, key string, arr []byte) []byte {
return AppendArray(AppendHeader(dst, TypeArray, key), arr)
}
// BuildArray will append a BSON array to dst built from values.
func BuildArray(dst []byte, values ...Value) []byte {
idx, dst := ReserveLength(dst)
for pos, val := range values {
dst = AppendValueElement(dst, strconv.Itoa(pos), val)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildArrayElement will create an array element using the provided values.
func BuildArrayElement(dst []byte, key string, values ...Value) []byte {
return BuildArray(AppendHeader(dst, TypeArray, key), values...)
}
// ReadArray will read an array from src. If there are not enough bytes it
// will return false.
func ReadArray(src []byte) (arr Array, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendBinary will append subtype and b to dst and return the extended buffer.
func AppendBinary(dst []byte, subtype byte, b []byte) []byte {
if subtype == 0x02 {
return appendBinarySubtype2(dst, subtype, b)
}
dst = append(appendLength(dst, int32(len(b))), subtype)
return append(dst, b...)
}
// AppendBinaryElement will append a BSON binary element using key, subtype, and
// b to dst and return the extended buffer.
func AppendBinaryElement(dst []byte, key string, subtype byte, b []byte) []byte {
return AppendBinary(AppendHeader(dst, TypeBinary, key), subtype, b)
}
// ReadBinary will read a subtype and bin from src. If there are not enough bytes it
// will return false.
func ReadBinary(src []byte) (subtype byte, bin []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok {
return 0x00, nil, src, false
}
if len(rem) < 1 { // subtype
return 0x00, nil, src, false
}
subtype, rem = rem[0], rem[1:]
if len(rem) < int(length) {
return 0x00, nil, src, false
}
if subtype == 0x02 {
length, rem, ok = ReadLength(rem)
if !ok || len(rem) < int(length) {
return 0x00, nil, src, false
}
}
return subtype, rem[:length], rem[length:], true
}
// AppendUndefinedElement will append a BSON undefined element using key to dst
// and return the extended buffer.
func AppendUndefinedElement(dst []byte, key string) []byte {
return AppendHeader(dst, TypeUndefined, key)
}
// AppendObjectID will append oid to dst and return the extended buffer.
func AppendObjectID(dst []byte, oid objectID) []byte { return append(dst, oid[:]...) }
// AppendObjectIDElement will append a BSON ObjectID element using key and oid to dst
// and return the extended buffer.
func AppendObjectIDElement(dst []byte, key string, oid objectID) []byte {
return AppendObjectID(AppendHeader(dst, TypeObjectID, key), oid)
}
// ReadObjectID will read an ObjectID from src. If there are not enough bytes it
// will return false.
func ReadObjectID(src []byte) ([12]byte, []byte, bool) {
var oid objectID
idLen := cap(oid)
if len(src) < idLen {
return oid, src, false
}
copy(oid[:], src[0:idLen])
return oid, src[idLen:], true
}
// AppendBoolean will append b to dst and return the extended buffer.
func AppendBoolean(dst []byte, b bool) []byte {
if b {
return append(dst, 0x01)
}
return append(dst, 0x00)
}
// AppendBooleanElement will append a BSON boolean element using key and b to dst
// and return the extended buffer.
func AppendBooleanElement(dst []byte, key string, b bool) []byte {
return AppendBoolean(AppendHeader(dst, TypeBoolean, key), b)
}
// ReadBoolean will read a bool from src. If there are not enough bytes it
// will return false.
func ReadBoolean(src []byte) (bool, []byte, bool) {
if len(src) < 1 {
return false, src, false
}
return src[0] == 0x01, src[1:], true
}
// AppendDateTime will append dt to dst and return the extended buffer.
func AppendDateTime(dst []byte, dt int64) []byte { return binaryutil.Append64(dst, dt) }
// AppendDateTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendDateTimeElement(dst []byte, key string, dt int64) []byte {
return AppendDateTime(AppendHeader(dst, TypeDateTime, key), dt)
}
// ReadDateTime will read an int64 datetime from src. If there are not enough bytes it
// will return false.
func ReadDateTime(src []byte) (int64, []byte, bool) { return binaryutil.ReadI64(src) }
// AppendTime will append time as a BSON DateTime to dst and return the extended buffer.
func AppendTime(dst []byte, t time.Time) []byte {
return AppendDateTime(dst, t.Unix()*1000+int64(t.Nanosecond()/1e6))
}
// AppendTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendTimeElement(dst []byte, key string, t time.Time) []byte {
return AppendTime(AppendHeader(dst, TypeDateTime, key), t)
}
// ReadTime will read an time.Time datetime from src. If there are not enough bytes it
// will return false.
func ReadTime(src []byte) (time.Time, []byte, bool) {
dt, rem, ok := binaryutil.ReadI64(src)
return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok
}
// AppendNullElement will append a BSON null element using key to dst
// and return the extended buffer.
func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, TypeNull, key) }
// AppendRegex will append pattern and options to dst and return the extended buffer.
func AppendRegex(dst []byte, pattern, options string) []byte {
if !isValidCString(pattern) || !isValidCString(options) {
panic(invalidRegexPanicMsg)
}
return append(dst, pattern+nullTerminator+options+nullTerminator...)
}
// AppendRegexElement will append a BSON regex element using key, pattern, and
// options to dst and return the extended buffer.
func AppendRegexElement(dst []byte, key, pattern, options string) []byte {
return AppendRegex(AppendHeader(dst, TypeRegex, key), pattern, options)
}
// ReadRegex will read a pattern and options from src. If there are not enough bytes it
// will return false.
func ReadRegex(src []byte) (pattern, options string, rem []byte, ok bool) {
pattern, rem, ok = binaryutil.ReadCString(src)
if !ok {
return "", "", src, false
}
options, rem, ok = binaryutil.ReadCString(rem)
if !ok {
return "", "", src, false
}
return pattern, options, rem, true
}
// AppendDBPointer will append ns and oid to dst and return the extended buffer.
func AppendDBPointer(dst []byte, ns string, oid objectID) []byte {
return append(appendstring(dst, ns), oid[:]...)
}
// AppendDBPointerElement will append a BSON DBPointer element using key, ns,
// and oid to dst and return the extended buffer.
func AppendDBPointerElement(dst []byte, key, ns string, oid objectID) []byte {
return AppendDBPointer(AppendHeader(dst, TypeDBPointer, key), ns, oid)
}
// ReadDBPointer will read a ns and oid from src. If there are not enough bytes it
// will return false.
func ReadDBPointer(src []byte) (ns string, oid [12]byte, rem []byte, ok bool) {
ns, rem, ok = readstring(src)
if !ok {
return "", objectID{}, src, false
}
oid, rem, ok = ReadObjectID(rem)
if !ok {
return "", objectID{}, src, false
}
return ns, oid, rem, true
}
// AppendJavaScript will append js to dst and return the extended buffer.
func AppendJavaScript(dst []byte, js string) []byte { return appendstring(dst, js) }
// AppendJavaScriptElement will append a BSON JavaScript element using key and
// js to dst and return the extended buffer.
func AppendJavaScriptElement(dst []byte, key, js string) []byte {
return AppendJavaScript(AppendHeader(dst, TypeJavaScript, key), js)
}
// ReadJavaScript will read a js string from src. If there are not enough bytes it
// will return false.
func ReadJavaScript(src []byte) (js string, rem []byte, ok bool) { return readstring(src) }
// AppendSymbol will append symbol to dst and return the extended buffer.
func AppendSymbol(dst []byte, symbol string) []byte { return appendstring(dst, symbol) }
// AppendSymbolElement will append a BSON symbol element using key and symbol to dst
// and return the extended buffer.
func AppendSymbolElement(dst []byte, key, symbol string) []byte {
return AppendSymbol(AppendHeader(dst, TypeSymbol, key), symbol)
}
// ReadSymbol will read a symbol string from src. If there are not enough bytes it
// will return false.
func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readstring(src) }
// AppendCodeWithScope will append code and scope to dst and return the extended buffer.
func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte {
length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope
dst = appendLength(dst, length)
return append(appendstring(dst, code), scope...)
}
// AppendCodeWithScopeElement will append a BSON code with scope element using
// key, code, and scope to dst
// and return the extended buffer.
func AppendCodeWithScopeElement(dst []byte, key, code string, scope []byte) []byte {
return AppendCodeWithScope(AppendHeader(dst, TypeCodeWithScope, key), code, scope)
}
// ReadCodeWithScope will read code and scope from src. If there are not enough bytes it
// will return false.
func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok || len(src) < int(length) {
return "", nil, src, false
}
code, rem, ok = readstring(rem)
if !ok {
return "", nil, src, false
}
scope, rem, ok = ReadDocument(rem)
if !ok {
return "", nil, src, false
}
return code, scope, rem, true
}
// AppendInt32 will append i32 to dst and return the extended buffer.
func AppendInt32(dst []byte, i32 int32) []byte { return binaryutil.Append32(dst, i32) }
// AppendInt32Element will append a BSON int32 element using key and i32 to dst
// and return the extended buffer.
func AppendInt32Element(dst []byte, key string, i32 int32) []byte {
return AppendInt32(AppendHeader(dst, TypeInt32, key), i32)
}
// ReadInt32 will read an int32 from src. If there are not enough bytes it
// will return false.
func ReadInt32(src []byte) (int32, []byte, bool) { return binaryutil.ReadI32(src) }
// AppendTimestamp will append t and i to dst and return the extended buffer.
func AppendTimestamp(dst []byte, t, i uint32) []byte {
return binaryutil.Append32(binaryutil.Append32(dst, i), t) // i is the lower 4 bytes, t is the higher 4 bytes
}
// AppendTimestampElement will append a BSON timestamp element using key, t, and
// i to dst and return the extended buffer.
func AppendTimestampElement(dst []byte, key string, t, i uint32) []byte {
return AppendTimestamp(AppendHeader(dst, TypeTimestamp, key), t, i)
}
// ReadTimestamp will read t and i from src. If there are not enough bytes it
// will return false.
func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) {
i, rem, ok = binaryutil.ReadU32(src)
if !ok {
return 0, 0, src, false
}
t, rem, ok = binaryutil.ReadU32(rem)
if !ok {
return 0, 0, src, false
}
return t, i, rem, true
}
// AppendInt64 will append i64 to dst and return the extended buffer.
func AppendInt64(dst []byte, i64 int64) []byte { return binaryutil.Append64(dst, i64) }
// AppendInt64Element will append a BSON int64 element using key and i64 to dst
// and return the extended buffer.
func AppendInt64Element(dst []byte, key string, i64 int64) []byte {
return AppendInt64(AppendHeader(dst, TypeInt64, key), i64)
}
// ReadInt64 will read an int64 from src. If there are not enough bytes it
// will return false.
func ReadInt64(src []byte) (int64, []byte, bool) { return binaryutil.ReadI64(src) }
// AppendDecimal128 will append high and low parts of a d128 to dst and return the extended buffer.
func AppendDecimal128(dst []byte, high, low uint64) []byte {
return binaryutil.Append64(binaryutil.Append64(dst, low), high)
}
// AppendDecimal128Element will append high and low parts of a BSON bson.Decimal128 element using key and
// d128 to dst and return the extended buffer.
func AppendDecimal128Element(dst []byte, key string, high, low uint64) []byte {
return AppendDecimal128(AppendHeader(dst, TypeDecimal128, key), high, low)
}
// ReadDecimal128 will read high and low parts of a bson.Decimal128 from src. If there are not enough bytes it
// will return false.
func ReadDecimal128(src []byte) (high uint64, low uint64, rem []byte, ok bool) {
low, rem, ok = binaryutil.ReadU64(src)
if !ok {
return 0, 0, src, false
}
high, rem, ok = binaryutil.ReadU64(rem)
if !ok {
return 0, 0, src, false
}
return high, low, rem, true
}
// AppendMaxKeyElement will append a BSON max key element using key to dst
// and return the extended buffer.
func AppendMaxKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, TypeMaxKey, key)
}
// AppendMinKeyElement will append a BSON min key element using key to dst
// and return the extended buffer.
func AppendMinKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, TypeMinKey, key)
}
// EqualValue will return true if the two values are equal.
func EqualValue(t1, t2 Type, v1, v2 []byte) bool {
if t1 != t2 {
return false
}
v1, _, ok := readValue(v1, t1)
if !ok {
return false
}
v2, _, ok = readValue(v2, t2)
if !ok {
return false
}
return bytes.Equal(v1, v2)
}
// valueLength will determine the length of the next value contained in src as if it
// is type t. The returned bool will be false if there are not enough bytes in src for
// a value of type t.
func valueLength(src []byte, t Type) (int32, bool) {
var length int32
ok := true
switch t {
case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
length, _, ok = ReadLength(src)
case TypeBinary:
length, _, ok = ReadLength(src)
length += 4 + 1 // binary length + subtype byte
case TypeBoolean:
length = 1
case TypeDBPointer:
length, _, ok = ReadLength(src)
length += 4 + 12 // string length + ObjectID length
case TypeDateTime, TypeDouble, TypeInt64, TypeTimestamp:
length = 8
case TypeDecimal128:
length = 16
case TypeInt32:
length = 4
case TypeJavaScript, TypeString, TypeSymbol:
length, _, ok = ReadLength(src)
length += 4
case TypeMaxKey, TypeMinKey, TypeNull, TypeUndefined:
length = 0
case TypeObjectID:
length = 12
case TypeRegex:
regex := bytes.IndexByte(src, 0x00)
if regex < 0 {
ok = false
break
}
pattern := bytes.IndexByte(src[regex+1:], 0x00)
if pattern < 0 {
ok = false
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1)
default:
ok = false
}
return length, ok
}
func readValue(src []byte, t Type) ([]byte, []byte, bool) {
length, ok := valueLength(src, t)
if !ok || int(length) > len(src) {
return nil, src, false
}
return src[:length], src[length:], true
}
// ReserveLength reserves the space required for length and returns the index where to write the length
// and the []byte with reserved space.
func ReserveLength(dst []byte) (int32, []byte) {
index := len(dst)
return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00)
}
// UpdateLength updates the length at index with length and returns the []byte.
func UpdateLength(dst []byte, index, length int32) []byte {
binary.LittleEndian.PutUint32(dst[index:], uint32(length))
return dst
}
func appendLength(dst []byte, l int32) []byte { return binaryutil.Append32(dst, l) }
// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
// there aren't enough bytes to read a valid length, src is returned unomdified and the returned
// bool will be false.
func ReadLength(src []byte) (int32, []byte, bool) {
ln, src, ok := binaryutil.ReadI32(src)
if ln < 0 {
return ln, src, false
}
return ln, src, ok
}
func appendstring(dst []byte, s string) []byte {
l := int32(len(s) + 1)
dst = appendLength(dst, l)
dst = append(dst, s...)
return append(dst, 0x00)
}
func readstring(src []byte) (string, []byte, bool) {
l, rem, ok := ReadLength(src)
if !ok {
return "", src, false
}
if len(src[4:]) < int(l) || l == 0 {
return "", src, false
}
return string(rem[:l-1]), rem[l:], true
}
// readLengthBytes attempts to read a length and that number of bytes. This
// function requires that the length include the four bytes for itself.
func readLengthBytes(src []byte) ([]byte, []byte, bool) {
l, _, ok := ReadLength(src)
if !ok {
return nil, src, false
}
if l < 4 {
return nil, src, false
}
if len(src) < int(l) {
return nil, src, false
}
return src[:l], src[l:], true
}
func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte {
dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes
dst = append(dst, subtype)
dst = appendLength(dst, int32(len(b)))
return append(dst, b...)
}
func isValidCString(cs string) bool {
return !strings.ContainsRune(cs, '\x00')
}
+34
View File
@@ -0,0 +1,34 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsoncore is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
//
// Package bsoncore contains functions that can be used to encode and decode
// BSON elements and values to or from a slice of bytes. These functions are
// aimed at allowing low level manipulation of BSON and can be used to build a
// higher level BSON library.
//
// The Read* functions within this package return the values of the element and
// a boolean indicating if the values are valid. A boolean was used instead of
// an error because any error that would be returned would be the same: not
// enough bytes. This library attempts to do no validation, it will only return
// false if there are not enough bytes for an item to be read. For example, the
// ReadDocument function checks the length, if that length is larger than the
// number of bytes available, it will return false, if there are enough bytes,
// it will return those bytes and true. It is the consumers responsibility to
// validate those bytes.
//
// The Append* functions within this package will append the type value to the
// given dst slice. If the slice has enough capacity, it will not grow the
// slice. The Append*Element functions within this package operate in the same
// way, but additionally append the BSON type and the key before the value.
package bsoncore
+431
View File
@@ -0,0 +1,431 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"errors"
"fmt"
"io"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/v2/internal/binaryutil"
)
// ValidationError is an error type returned when attempting to validate a document or array.
type ValidationError string
func (ve ValidationError) Error() string { return string(ve) }
// NewDocumentLengthError creates and returns an error for when the length of a document exceeds the
// bytes available.
func NewDocumentLengthError(length, rem int) error {
return lengthError("document", length, rem)
}
func lengthError(bufferType string, length, rem int) error {
return ValidationError(fmt.Sprintf("%v length exceeds available bytes. length=%d remainingBytes=%d",
bufferType, length, rem))
}
// InsufficientBytesError indicates that there were not enough bytes to read the next component.
type InsufficientBytesError struct {
Source []byte
Remaining []byte
}
// NewInsufficientBytesError creates a new InsufficientBytesError with the given Document and
// remaining bytes.
func NewInsufficientBytesError(src, rem []byte) InsufficientBytesError {
return InsufficientBytesError{Source: src, Remaining: rem}
}
// Error implements the error interface.
func (ibe InsufficientBytesError) Error() string {
return "too few bytes to read next component"
}
// Equal checks that err2 also is an ErrTooSmall.
func (ibe InsufficientBytesError) Equal(err2 error) bool {
switch err2.(type) {
case InsufficientBytesError:
return true
default:
return false
}
}
// InvalidDepthTraversalError is returned when attempting a recursive Lookup when one component of
// the path is neither an embedded document nor an array.
type InvalidDepthTraversalError struct {
Key string
Type Type
}
func (idte InvalidDepthTraversalError) Error() string {
return fmt.Sprintf(
"attempt to traverse into %s, but it's type is %s, not %s nor %s",
idte.Key, idte.Type, TypeEmbeddedDocument, TypeArray,
)
}
// ErrMissingNull is returned when a document or array's last byte is not null.
const ErrMissingNull ValidationError = "document or array end is missing null byte"
// ErrInvalidLength indicates that a length in a binary representation of a BSON document or array
// is invalid.
const ErrInvalidLength ValidationError = "document or array length is invalid"
// ErrNilReader indicates that an operation was attempted on a nil io.Reader.
var ErrNilReader = errors.New("nil reader")
// ErrEmptyKey indicates that no key was provided to a Lookup method.
var ErrEmptyKey = errors.New("empty key provided")
// ErrElementNotFound indicates that an Element matching a certain condition does not exist.
var ErrElementNotFound = errors.New("element not found")
// ErrOutOfBounds indicates that an index provided to access something was invalid.
var ErrOutOfBounds = errors.New("out of bounds")
// Document is a raw bytes representation of a BSON document.
type Document []byte
// NewDocumentFromReader reads a document from r. This function will only validate the length is
// correct and that the document ends with a null byte.
func NewDocumentFromReader(r io.Reader) (Document, error) {
return newBufferFromReader(r)
}
func newBufferFromReader(r io.Reader) ([]byte, error) {
if r == nil {
return nil, ErrNilReader
}
var lengthBytes [4]byte
// ReadFull guarantees that we will have read at least len(lengthBytes) if err == nil
_, err := io.ReadFull(r, lengthBytes[:])
if err != nil {
return nil, err
}
length, _, _ := binaryutil.ReadI32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length
if length < 0 {
return nil, ErrInvalidLength
}
buffer := make([]byte, length)
copy(buffer, lengthBytes[:])
_, err = io.ReadFull(r, buffer[4:])
if err != nil {
return nil, err
}
if buffer[length-1] != 0x00 {
return nil, ErrMissingNull
}
return buffer, nil
}
// Lookup searches the document, potentially recursively, for the given key. If there are multiple
// keys provided, this method will recurse down, as long as the top and intermediate nodes are
// either documents or arrays. If an error occurs or if the value doesn't exist, an empty Value is
// returned.
func (d Document) Lookup(key ...string) Value {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr is the same as Lookup, except it returns an error in addition to an empty Value.
func (d Document) LookupErr(key ...string) (Value, error) {
if len(key) < 1 {
return Value{}, ErrEmptyKey
}
length, rem, ok := ReadLength(d)
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
// We use `KeyBytes` rather than `Key` to avoid a needless string alloc.
if string(elem.KeyBytes()) != key[0] {
continue
}
if len(key) > 1 {
tt := Type(elem[0])
switch tt {
case TypeEmbeddedDocument:
val, err := elem.Value().Document().LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
case TypeArray:
// Convert to Document to continue Lookup recursion.
val, err := Document(elem.Value().Array()).LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
default:
return Value{}, InvalidDepthTraversalError{Key: elem.Key(), Type: tt}
}
}
return elem.ValueErr()
}
return Value{}, ErrElementNotFound
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (d Document) Index(index uint) Element {
elem, err := d.IndexErr(index)
if err != nil {
panic(err)
}
return elem
}
// IndexErr searches for and retrieves the element at the given index.
func (d Document) IndexErr(index uint) (Element, error) {
return indexErr(d, index)
}
func indexErr(b []byte, index uint) (Element, error) {
length, rem, ok := ReadLength(b)
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
length -= 4
var current uint
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
if current != index {
current++
continue
}
return elem, nil
}
return nil, ErrOutOfBounds
}
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (d Document) DebugString() string {
if len(d) < 5 {
return "<malformed>"
}
var buf strings.Builder
buf.WriteString("Document")
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString("){")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
buf.WriteString(elem.DebugString())
}
buf.WriteByte('}')
return buf.String()
}
// String outputs an ExtendedJSON version of Document. If the document is not valid, this method
// returns an empty string.
func (d Document) String() string {
str, _ := d.StringN(-1)
return str
}
// StringN stringifies a document. If N is non-negative, it will truncate the string to N bytes.
// Otherwise, it will return the full string representation. The second return value indicates
// whether the string was truncated or not.
func (d Document) StringN(n int) (string, bool) {
length, rem, ok := ReadLength(d)
if !ok || length < 5 {
return "", false
}
length -= 4 // length bytes
length-- // final null byte
if n == 0 {
return "", true
}
var buf strings.Builder
buf.WriteByte('{')
var truncated bool
var elem Element
var str string
first := true
for length > 0 && !truncated {
needStrLen := -1
// Set needStrLen if n is positive, meaning we want to limit the string length.
if n > 0 {
// Stop stringifying if we reach the limit, that also ensures needStrLen is
// greater than 0 if we need to limit the length.
if buf.Len() >= n {
truncated = true
break
}
needStrLen = n - buf.Len()
}
// Append a comma if this is not the first element.
if !first {
buf.WriteByte(',')
// If we are truncating, we need to account for the comma in the length.
if needStrLen > 0 {
needStrLen--
if needStrLen == 0 {
truncated = true
break
}
}
}
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
// Exit on malformed element.
if !ok || length < 0 {
return "", false
}
// Delegate to StringN() on the element.
str, truncated = elem.StringN(needStrLen)
buf.WriteString(str)
first = false
}
if n <= 0 || (buf.Len() < n && !truncated) {
buf.WriteByte('}')
} else {
truncated = true
}
return buf.String(), truncated
}
// Elements returns this document as a slice of elements. The returned slice will contain valid
// elements. If the document is not valid, the elements up to the invalid point will be returned
// along with an error.
func (d Document) Elements() ([]Element, error) {
length, rem, ok := ReadLength(d)
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
var elems []Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return elems, NewInsufficientBytesError(d, rem)
}
if err := elem.Validate(); err != nil {
return elems, err
}
elems = append(elems, elem)
}
return elems, nil
}
// Values returns this document as a slice of values. The returned slice will contain valid values.
// If the document is not valid, the values up to the invalid point will be returned along with an
// error.
func (d Document) Values() ([]Value, error) {
return values(d)
}
func values(b []byte) ([]Value, error) {
length, rem, ok := ReadLength(b)
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
length -= 4
var elem Element
var vals []Value
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return vals, NewInsufficientBytesError(b, rem)
}
if err := elem.Value().Validate(); err != nil {
return vals, err
}
vals = append(vals, elem.Value())
}
return vals, nil
}
// Validate validates the document and ensures the elements contained within are valid.
func (d Document) Validate() error {
length, rem, ok := ReadLength(d)
if !ok {
return NewInsufficientBytesError(d, rem)
}
if int(length) > len(d) {
return NewDocumentLengthError(int(length), len(d))
}
if d[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(d, rem)
}
err := elem.Validate()
if err != nil {
return err
}
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}
+213
View File
@@ -0,0 +1,213 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/v2/internal/bsoncoreutil"
)
// MalformedElementError represents a class of errors that RawElement methods return.
type MalformedElementError string
func (mee MalformedElementError) Error() string { return string(mee) }
// ErrElementMissingKey is returned when a RawElement is missing a key.
const ErrElementMissingKey MalformedElementError = "element is missing key"
// ErrElementMissingType is returned when a RawElement is missing a type.
const ErrElementMissingType MalformedElementError = "element is missing type"
// Element is a raw bytes representation of a BSON element.
type Element []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If knowing if the element is valid is important, use KeyErr.
func (e Element) Key() string {
key, _ := e.KeyErr()
return key
}
// KeyBytes returns the key for this element as a []byte. If the element is not valid, this method
// returns an empty string. If knowing if the element is valid is important, use KeyErr. This method
// will not include the null byte at the end of the key in the slice of bytes.
func (e Element) KeyBytes() []byte {
key, _ := e.KeyBytesErr()
return key
}
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (e Element) KeyErr() (string, error) {
key, err := e.KeyBytesErr()
return string(key), err
}
// KeyBytesErr returns the key for this element as a []byte, returning an error if the element is
// not valid.
func (e Element) KeyBytesErr() ([]byte, error) {
if len(e) == 0 {
return nil, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return nil, ErrElementMissingKey
}
return e[1 : idx+1], nil
}
// Validate ensures the element is a valid BSON element.
func (e Element) Validate() error {
if len(e) < 1 {
return ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return ErrElementMissingKey
}
return Value{Type: Type(e[0]), Data: e[idx+2:]}.Validate()
}
// CompareKey will compare this element's key to key. This method makes it easy to compare keys
// without needing to allocate a string. The key may be null terminated. If a valid key cannot be
// read this method will return false.
func (e Element) CompareKey(key []byte) bool {
if len(e) < 2 {
return false
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return false
}
if index := bytes.IndexByte(key, 0x00); index > -1 {
key = key[:index]
}
return bytes.Equal(e[1:idx+1], key)
}
// Value returns the value of this element. If the element is not valid, this method returns an
// empty Value. If knowing if the element is valid is important, use ValueErr.
func (e Element) Value() Value {
val, _ := e.ValueErr()
return val
}
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (e Element) ValueErr() (Value, error) {
if len(e) == 0 {
return Value{}, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return Value{}, ErrElementMissingKey
}
val, rem, exists := ReadValue(e[idx+2:], Type(e[0]))
if !exists {
return Value{}, NewInsufficientBytesError(e, rem)
}
return val, nil
}
// String implements the fmt.String interface. The output will be in extended JSON format.
func (e Element) String() string {
str, _ := e.StringN(-1)
return str
}
// StringN will return values in extended JSON format that will stringify an element upto N bytes.
// If N is non-negative, it will truncate the string to N bytes. Otherwise, it will return the full
// string representation. The second return value indicates whether the string was truncated or not.
// If the element is not valid, this returns an empty string
func (e Element) StringN(n int) (string, bool) {
if len(e) == 0 {
return "", false
}
if n == 0 {
return "", true
}
if n == 1 {
return `"`, true
}
t := Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx <= 0 {
return "", false
}
key := e[1 : idx+1]
var buf strings.Builder
buf.WriteByte('"')
const suffix = `": `
switch {
case n < 0 || idx <= n-buf.Len()-len(suffix):
buf.Write(key)
buf.WriteString(suffix)
case idx < n:
buf.Write(key)
buf.WriteString(suffix[:n-idx-1])
return buf.String(), true
default:
buf.WriteString(bsoncoreutil.Truncate(string(key), n-1))
return buf.String(), true
}
needStrLen := -1
// Set needStrLen if n is positive, meaning we want to limit the string length.
if n > 0 {
// Stop stringifying if we reach the limit, that also ensures needStrLen is
// greater than 0 if we need to limit the length.
if buf.Len() >= n {
return buf.String(), true
}
needStrLen = n - buf.Len()
}
val, _, valid := ReadValue(e[idx+2:], t)
if !valid {
return "", false
}
var str string
var truncated bool
if _, ok := val.StringValueOK(); ok {
str, truncated = val.StringN(needStrLen)
} else if arr, ok := val.ArrayOK(); ok {
str, truncated = arr.StringN(needStrLen)
} else {
str = val.String()
if needStrLen > 0 && len(str) > needStrLen {
truncated = true
str = bsoncoreutil.Truncate(str, needStrLen)
}
}
buf.WriteString(str)
return buf.String(), truncated
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (e Element) DebugString() string {
if len(e) == 0 {
return "<malformed>"
}
t := Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return fmt.Sprintf(`bson.Element{[%s]<malformed>}`, t)
}
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
val, _, valid := ReadValue(valBytes, t)
if !valid {
return fmt.Sprintf(`bson.Element{[%s]"%s": <malformed>}`, t, key)
}
return fmt.Sprintf(`bson.Element{[%s]"%s": %v}`, t, key, val)
}
+113
View File
@@ -0,0 +1,113 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"errors"
"fmt"
"io"
)
// errCorruptedDocument is returned when a full document couldn't be read from
// the sequence.
var errCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document")
// Iterator maintains a list of BSON values and keeps track of the current
// position in relation to its Next() method.
type Iterator struct {
List Array // List of BSON values
pos int // The position of the iterator in the list in reference to Next()
}
// Count returned the number of elements in the iterator's list.
func (iter *Iterator) Count() int {
if iter == nil {
return 0
}
_, rem, ok := ReadLength(iter.List)
if !ok {
return 0
}
var count int
for len(rem) > 1 {
_, rem, ok = ReadElement(rem)
if !ok {
return 0
}
count++
}
return count
}
// Empty returns true if the iterator's list is empty.
func (iter *Iterator) Empty() bool {
return len(iter.List) <= 5
}
// Reset will reset the iteration point for the Next method to the beginning of
// the list.
func (iter *Iterator) Reset() {
iter.pos = 0
}
// Documents traverses the list as documents and returns them. This method
// assumes that the underlying list is composed of documents and will return
// an error otherwise.
func (iter *Iterator) Documents() ([]Document, error) {
if iter == nil || len(iter.List) == 0 {
return nil, nil
}
vals, err := iter.List.Values()
if err != nil {
return nil, errCorruptedDocument
}
docs := make([]Document, 0, len(vals))
for _, v := range vals {
if v.Type != TypeEmbeddedDocument {
return nil, fmt.Errorf("invalid DocumentSequence: a non-document value was found in sequence")
}
docs = append(docs, v.Data)
}
return docs, nil
}
// Next retrieves the next value from the list and returns it. This method will
// return io.EOF when it has reached the end of the list.
func (iter *Iterator) Next() (*Value, error) {
if iter == nil || iter.pos >= len(iter.List) {
return nil, io.EOF
}
if iter.pos < 4 {
if len(iter.List) < 4 {
return nil, errCorruptedDocument
}
iter.pos = 4 // Skip the length of the document
}
rem := iter.List[iter.pos:]
if len(rem) == 1 && rem[0] == 0x00 {
return nil, io.EOF // At the end of the document
}
elem, _, ok := ReadElement(rem)
if !ok {
return nil, errCorruptedDocument
}
iter.pos += len(elem)
val := elem.Value()
return &val, nil
}
+223
View File
@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsoncore
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
+85
View File
@@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
// Type represents a BSON type.
type Type byte
// String returns the string representation of the BSON type's name.
func (bt Type) String() string {
switch bt {
case '\x01':
return "double"
case '\x02':
return "string"
case '\x03':
return "embedded document"
case '\x04':
return "array"
case '\x05':
return "binary"
case '\x06':
return "undefined"
case '\x07':
return "objectID"
case '\x08':
return "boolean"
case '\x09':
return "UTC datetime"
case '\x0A':
return "null"
case '\x0B':
return "regex"
case '\x0C':
return "dbPointer"
case '\x0D':
return "javascript"
case '\x0E':
return "symbol"
case '\x0F':
return "code with scope"
case '\x10':
return "32-bit integer"
case '\x11':
return "timestamp"
case '\x12':
return "64-bit integer"
case '\x13':
return "128-bit decimal"
case '\x7F':
return "max key"
case '\xFF':
return "min key"
default:
return "invalid"
}
}
// BSON element types as described in https://bsonspec.org/spec.html.
const (
TypeDouble Type = 0x01
TypeString Type = 0x02
TypeEmbeddedDocument Type = 0x03
TypeArray Type = 0x04
TypeBinary Type = 0x05
TypeUndefined Type = 0x06
TypeObjectID Type = 0x07
TypeBoolean Type = 0x08
TypeDateTime Type = 0x09
TypeNull Type = 0x0A
TypeRegex Type = 0x0B
TypeDBPointer Type = 0x0C
TypeJavaScript Type = 0x0D
TypeSymbol Type = 0x0E
TypeCodeWithScope Type = 0x0F
TypeInt32 Type = 0x10
TypeTimestamp Type = 0x11
TypeInt64 Type = 0x12
TypeDecimal128 Type = 0x13
TypeMaxKey Type = 0x7F
TypeMinKey Type = 0xFF
)
File diff suppressed because it is too large Load Diff