Skip to content

Commit 84dd78c

Browse files
olavloiteskuruppu
andauthored
feat: support NUMERIC as key (#3627)
Co-authored-by: skuruppu <skuruppu@google.com>
1 parent 8617812 commit 84dd78c

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

spanner/client_test.go

+25
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222
"io"
23+
"math/big"
2324
"os"
2425
"strings"
2526
"testing"
@@ -2267,3 +2268,27 @@ func TestClient_DoForEachRow_ShouldEndSpanWithQueryError(t *testing.T) {
22672268
t.Errorf("Span status mismatch\nGot: %v\nWant: %v", s.Code, codes.InvalidArgument)
22682269
}
22692270
}
2271+
2272+
func TestClient_Single_Read_WithNumericKey(t *testing.T) {
2273+
t.Parallel()
2274+
2275+
_, client, teardown := setupMockedTestServer(t)
2276+
defer teardown()
2277+
ctx := context.Background()
2278+
iter := client.Single().Read(ctx, "Albums", KeySets(Key{*big.NewRat(1, 1)}), []string{"SingerId", "AlbumId", "AlbumTitle"})
2279+
defer iter.Stop()
2280+
rowCount := int64(0)
2281+
for {
2282+
_, err := iter.Next()
2283+
if err == iterator.Done {
2284+
break
2285+
}
2286+
if err != nil {
2287+
t.Fatal(err)
2288+
}
2289+
rowCount++
2290+
}
2291+
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
2292+
t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
2293+
}
2294+
}

spanner/key.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package spanner
1919
import (
2020
"bytes"
2121
"fmt"
22+
"math/big"
2223
"time"
2324

2425
"cloud.google.com/go/civil"
@@ -84,7 +85,7 @@ func keyPartValue(part interface{}) (pb *proto3.Value, err error) {
8485
pb, _, err = encodeValue(int64(v))
8586
case float32:
8687
pb, _, err = encodeValue(float64(v))
87-
case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate:
88+
case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate, big.Rat, NullNumeric:
8889
pb, _, err = encodeValue(v)
8990
case Encoder:
9091
part, err = v.EncodeSpanner()
@@ -150,7 +151,7 @@ func (key Key) elemString(b *bytes.Buffer, part interface{}) {
150151
} else {
151152
fmt.Fprint(b, nullString)
152153
}
153-
case NullInt64, NullFloat64, NullBool:
154+
case NullInt64, NullFloat64, NullBool, NullNumeric:
154155
// The above types implement fmt.Stringer.
155156
fmt.Fprintf(b, "%s", v)
156157
case NullString, NullDate, NullTime:
@@ -164,6 +165,8 @@ func (key Key) elemString(b *bytes.Buffer, part interface{}) {
164165
fmt.Fprintf(b, "%q", v)
165166
case time.Time:
166167
fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano))
168+
case big.Rat:
169+
fmt.Fprintf(b, "%v", NumericString(&v))
167170
case Encoder:
168171
var err error
169172
part, err = v.EncodeSpanner()

spanner/key_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package spanner
1818

1919
import (
2020
"errors"
21+
"math/big"
2122
"testing"
2223
"time"
2324

@@ -132,6 +133,11 @@ func TestKey(t *testing.T) {
132133
wantProto: listValueProto(stringProto("2016-11-15")),
133134
wantStr: `("2016-11-15")`,
134135
},
136+
{
137+
k: Key{*big.NewRat(1, 1)},
138+
wantProto: listValueProto(stringProto("1.000000000")),
139+
wantStr: `(1.000000000)`,
140+
},
135141
{
136142
k: Key{[]byte("value")},
137143
wantProto: listValueProto(bytesProto([]byte("value"))),
@@ -203,6 +209,16 @@ func TestKey(t *testing.T) {
203209
wantProto: listValueProto(stringProto("1"), nullProto(), stringProto("value"), floatProto(1.5), boolProto(true)),
204210
wantStr: `(1,<null>,"value",1.5,true)`,
205211
},
212+
{
213+
k: Key{NullNumeric{*big.NewRat(2, 3), true}},
214+
wantProto: listValueProto(stringProto("0.666666667")),
215+
wantStr: "(0.666666667)",
216+
},
217+
{
218+
k: Key{NullNumeric{big.Rat{}, false}},
219+
wantProto: listValueProto(nullProto()),
220+
wantStr: "(<null>)",
221+
},
206222
{
207223
k: Key{customKeyToString("value")},
208224
wantProto: listValueProto(stringProto("value")),

0 commit comments

Comments
 (0)