Skip to content

Commit a840d75

Browse files
authored
Strict validation when using SQL DB for visibility (#3905)
1 parent 52c3a9e commit a840d75

17 files changed

+376
-193
lines changed

common/persistence/visibility/defs.go

+17
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424

2525
package visibility
2626

27+
import (
28+
"go.temporal.io/server/common/persistence/sql/sqlplugin/mysql"
29+
"go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql"
30+
"go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite"
31+
)
32+
2733
const (
2834
// AdvancedVisibilityWritingModeOff means do not write to advanced visibility store
2935
AdvancedVisibilityWritingModeOff = "off"
@@ -40,3 +46,14 @@ func DefaultAdvancedVisibilityWritingMode(advancedVisibilityConfigExist bool) st
4046
}
4147
return AdvancedVisibilityWritingModeOff
4248
}
49+
50+
func AllowListForValidation(pluginName string) bool {
51+
switch pluginName {
52+
case mysql.PluginNameV8, postgresql.PluginNameV12, sqlite.PluginName:
53+
// Advanced visibility with SQL DB don't support list of values
54+
return false
55+
default:
56+
// Otherwise, enable for backward compatibility.
57+
return true
58+
}
59+
}

common/persistence/visibility/store/elasticsearch/visibility_store.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ func (s *visibilityStore) generateESDoc(request *store.InternalVisibilityRequest
874874
return nil, serviceerror.NewUnavailable(fmt.Sprintf("Unable to read search attribute types: %v", err))
875875
}
876876

877-
searchAttributes, err := searchattribute.Decode(request.SearchAttributes, &typeMap)
877+
searchAttributes, err := searchattribute.Decode(request.SearchAttributes, &typeMap, true)
878878
if err != nil {
879879
s.metricsHandler.Counter(metrics.ElasticsearchDocumentGenerateFailuresCount.GetMetricName()).Record(1)
880880
return nil, serviceerror.NewInternal(fmt.Sprintf("Unable to decode search attributes: %v", err))

common/persistence/visibility/store/elasticsearch/visibility_store_read_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ func (s *ESVisibilitySuite) TestParseESDoc_SearchAttributes() {
920920
info, err := s.visibilityStore.parseESDoc("", docSource, searchattribute.TestNameTypeMap, testNamespace)
921921
s.NoError(err)
922922
s.NotNil(info)
923-
customSearchAttributes, err := searchattribute.Decode(info.SearchAttributes, &searchattribute.TestNameTypeMap)
923+
customSearchAttributes, err := searchattribute.Decode(info.SearchAttributes, &searchattribute.TestNameTypeMap, true)
924924
s.NoError(err)
925925

926926
s.Len(customSearchAttributes, 7)

common/persistence/visibility/store/sql/visibility_store.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ func (s *VisibilityStore) prepareSearchAttributesForDb(
482482
}
483483

484484
var searchAttributes sqlplugin.VisibilitySearchAttributes
485-
searchAttributes, err = searchattribute.Decode(request.SearchAttributes, &saTypeMap)
485+
searchAttributes, err = searchattribute.Decode(request.SearchAttributes, &saTypeMap, false)
486486
if err != nil {
487487
return nil, err
488488
}

common/searchattribute/encode.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ func Encode(searchAttributes map[string]interface{}, typeMap *NameTypeMap) (*com
6767
// 1. type from typeMap,
6868
// 2. if typeMap is nil, type from MetadataType field is used.
6969
// In case of error, it will continue to next search attribute and return last error.
70-
func Decode(searchAttributes *commonpb.SearchAttributes, typeMap *NameTypeMap) (map[string]interface{}, error) {
70+
func Decode(
71+
searchAttributes *commonpb.SearchAttributes,
72+
typeMap *NameTypeMap,
73+
allowList bool,
74+
) (map[string]interface{}, error) {
7175
if len(searchAttributes.GetIndexedFields()) == 0 {
7276
return nil, nil
7377
}
@@ -84,7 +88,7 @@ func Decode(searchAttributes *commonpb.SearchAttributes, typeMap *NameTypeMap) (
8488
}
8589
}
8690

87-
searchAttributeValue, err := DecodeValue(saPayload, saType)
91+
searchAttributeValue, err := DecodeValue(saPayload, saType, allowList)
8892
if err != nil {
8993
lastErr = err
9094
result[saName] = nil

common/searchattribute/encode_test.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func Test_Decode_Success(t *testing.T) {
137137
}, typeMap)
138138
assert.NoError(err)
139139

140-
vals, err := Decode(sa, typeMap)
140+
vals, err := Decode(sa, typeMap, true)
141141
assert.NoError(err)
142142
assert.Len(vals, 6)
143143
assert.Equal("val1", vals["key1"])
@@ -154,7 +154,7 @@ func Test_Decode_Success(t *testing.T) {
154154
delete(sa.IndexedFields["key5"].Metadata, "type")
155155
delete(sa.IndexedFields["key6"].Metadata, "type")
156156

157-
vals, err = Decode(sa, typeMap)
157+
vals, err = Decode(sa, typeMap, true)
158158
assert.NoError(err)
159159
assert.Len(vals, 6)
160160
assert.Equal("val1", vals["key1"])
@@ -185,7 +185,7 @@ func Test_Decode_NilMap(t *testing.T) {
185185
}, typeMap)
186186
assert.NoError(err)
187187

188-
vals, err := Decode(sa, nil)
188+
vals, err := Decode(sa, nil, true)
189189
assert.NoError(err)
190190
assert.Len(sa.IndexedFields, 6)
191191
assert.Equal("val1", vals["key1"])
@@ -211,11 +211,15 @@ func Test_Decode_Error(t *testing.T) {
211211
}, typeMap)
212212
assert.NoError(err)
213213

214-
vals, err := Decode(sa, &NameTypeMap{customSearchAttributes: map[string]enumspb.IndexedValueType{
215-
"key1": enumspb.INDEXED_VALUE_TYPE_TEXT,
216-
"key4": enumspb.INDEXED_VALUE_TYPE_INT,
217-
"key3": enumspb.INDEXED_VALUE_TYPE_BOOL,
218-
}})
214+
vals, err := Decode(
215+
sa,
216+
&NameTypeMap{customSearchAttributes: map[string]enumspb.IndexedValueType{
217+
"key1": enumspb.INDEXED_VALUE_TYPE_TEXT,
218+
"key4": enumspb.INDEXED_VALUE_TYPE_INT,
219+
"key3": enumspb.INDEXED_VALUE_TYPE_BOOL,
220+
}},
221+
true,
222+
)
219223
assert.Error(err)
220224
assert.True(errors.Is(err, ErrInvalidName))
221225
assert.Len(sa.IndexedFields, 3)
@@ -227,7 +231,7 @@ func Test_Decode_Error(t *testing.T) {
227231
delete(sa.IndexedFields["key2"].Metadata, "type")
228232
delete(sa.IndexedFields["key3"].Metadata, "type")
229233

230-
vals, err = Decode(sa, nil)
234+
vals, err = Decode(sa, nil, true)
231235
assert.Error(err)
232236
assert.True(errors.Is(err, ErrInvalidType))
233237
assert.Len(vals, 3)

common/searchattribute/encode_value.go

+52-74
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,44 @@ func EncodeValue(val interface{}, t enumspb.IndexedValueType) (*commonpb.Payload
4848
// DecodeValue decodes search attribute value from Payload using (in order):
4949
// 1. passed type t.
5050
// 2. type from MetadataType field, if t is not specified.
51-
func DecodeValue(value *commonpb.Payload, t enumspb.IndexedValueType) (interface{}, error) {
51+
// allowList allows list of values when it's not keyword list type.
52+
func DecodeValue(
53+
value *commonpb.Payload,
54+
t enumspb.IndexedValueType,
55+
allowList bool,
56+
) (any, error) {
5257
if t == enumspb.INDEXED_VALUE_TYPE_UNSPECIFIED {
53-
t = enumspb.IndexedValueType(enumspb.IndexedValueType_value[string(value.Metadata[MetadataType])])
58+
t = enumspb.IndexedValueType(
59+
enumspb.IndexedValueType_value[string(value.Metadata[MetadataType])],
60+
)
5461
}
5562

56-
// Here are similar code sections for all types.
63+
switch t {
64+
case enumspb.INDEXED_VALUE_TYPE_BOOL:
65+
return decodeValueTyped[bool](value, allowList)
66+
case enumspb.INDEXED_VALUE_TYPE_DATETIME:
67+
return decodeValueTyped[time.Time](value, allowList)
68+
case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
69+
return decodeValueTyped[float64](value, allowList)
70+
case enumspb.INDEXED_VALUE_TYPE_INT:
71+
return decodeValueTyped[int64](value, allowList)
72+
case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
73+
return decodeValueTyped[string](value, allowList)
74+
case enumspb.INDEXED_VALUE_TYPE_TEXT:
75+
return decodeValueTyped[string](value, allowList)
76+
case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
77+
return decodeValueTyped[[]string](value, false)
78+
default:
79+
return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
80+
}
81+
}
82+
83+
// decodeValueTyped tries to decode to the given type.
84+
// If the input is a list and allowList is false, then it will return only the first element.
85+
// If the input is a list and allowList is true, then it will return the decoded list.
86+
//
87+
//nolint:revive // allowList is a control flag
88+
func decodeValueTyped[T any](value *commonpb.Payload, allowList bool) (any, error) {
5789
// At first, it tries to decode to pointer of actual type (i.e. `*string` for `string`).
5890
// This is to ensure that `nil` values are decoded back as `nil` using `NilPayloadConverter`.
5991
// If value is not `nil` but some value of expected type, the code relies on the fact that
@@ -62,82 +94,28 @@ func DecodeValue(value *commonpb.Payload, t enumspb.IndexedValueType) (interface
6294
// If decoding to pointer type fails, it tries to decode to array of the same type because
6395
// search attributes support polymorphism: field of specific type may also have an array of that type.
6496
// If resulting slice has zero length, it gets substitute with `nil` to treat nils and empty slices equally.
97+
// If allowList is true, it returns the list as it is. If allowList is false and the list has
98+
// only one element, then return it. Otherwise, return an error.
6599
// If search attribute value is `nil`, it means that search attribute needs to be removed from the document.
66-
67-
switch t {
68-
case enumspb.INDEXED_VALUE_TYPE_TEXT,
69-
enumspb.INDEXED_VALUE_TYPE_KEYWORD,
70-
enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
71-
var val *string
72-
if err := payload.Decode(value, &val); err != nil {
73-
var listVal []string
74-
err = payload.Decode(value, &listVal)
75-
if len(listVal) == 0 {
76-
return nil, err
77-
}
78-
return listVal, err
79-
}
80-
if val == nil {
81-
return nil, nil
100+
var val *T
101+
if err := payload.Decode(value, &val); err != nil {
102+
var listVal []T
103+
if err := payload.Decode(value, &listVal); err != nil {
104+
return nil, err
82105
}
83-
return *val, nil
84-
case enumspb.INDEXED_VALUE_TYPE_INT:
85-
var val *int64
86-
if err := payload.Decode(value, &val); err != nil {
87-
var listVal []int64
88-
err = payload.Decode(value, &listVal)
89-
if len(listVal) == 0 {
90-
return nil, err
91-
}
92-
return listVal, err
93-
}
94-
if val == nil {
95-
return nil, nil
96-
}
97-
return *val, nil
98-
case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
99-
var val *float64
100-
if err := payload.Decode(value, &val); err != nil {
101-
var listVal []float64
102-
err = payload.Decode(value, &listVal)
103-
if len(listVal) == 0 {
104-
return nil, err
105-
}
106-
return listVal, err
107-
}
108-
if val == nil {
106+
if len(listVal) == 0 {
109107
return nil, nil
110108
}
111-
return *val, nil
112-
case enumspb.INDEXED_VALUE_TYPE_BOOL:
113-
var val *bool
114-
if err := payload.Decode(value, &val); err != nil {
115-
var listVal []bool
116-
err = payload.Decode(value, &listVal)
117-
if len(listVal) == 0 {
118-
return nil, err
119-
}
120-
return listVal, err
109+
if allowList {
110+
return listVal, nil
121111
}
122-
if val == nil {
123-
return nil, nil
112+
if len(listVal) == 1 {
113+
return listVal[0], nil
124114
}
125-
return *val, nil
126-
case enumspb.INDEXED_VALUE_TYPE_DATETIME:
127-
var val *time.Time
128-
if err := payload.Decode(value, &val); err != nil {
129-
var listVal []time.Time
130-
err = payload.Decode(value, &listVal)
131-
if len(listVal) == 0 {
132-
return nil, err
133-
}
134-
return listVal, err
135-
}
136-
if val == nil {
137-
return nil, nil
138-
}
139-
return *val, nil
140-
default:
141-
return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
115+
return nil, fmt.Errorf("list of values not allowed for type %T", listVal[0])
116+
}
117+
if val == nil {
118+
return nil, nil
142119
}
120+
return *val, nil
143121
}

0 commit comments

Comments
 (0)