Skip to content

Commit cbed0fe

Browse files
authored
Fix wrap user query string in parenthesis (#3967)
1 parent 7ab7e9c commit cbed0fe

File tree

4 files changed

+17
-6
lines changed

4 files changed

+17
-6
lines changed

common/persistence/visibility/store/sql/query_converter.go

+11
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,17 @@ func (c *QueryConverter) convertSelectStmt(sel *sqlparser.Select) error {
218218
if err != nil {
219219
return err
220220
}
221+
222+
// Wrap user's query in parenthesis. This is to ensure that further changes
223+
// to the query won't affect the user's query.
224+
switch sel.Where.Expr.(type) {
225+
case *sqlparser.ParenExpr:
226+
// no-op: top-level expression is already a parenthesis
227+
default:
228+
sel.Where.Expr = &sqlparser.ParenExpr{
229+
Expr: sel.Where.Expr,
230+
}
231+
}
221232
}
222233

223234
// This logic comes from elasticsearch/visibility_store.go#convertQuery function.

common/persistence/visibility/store/sql/query_converter_mysql.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ func (c *mysqlQueryConverter) buildSelectStmt(
222222
queryArgs = append(queryArgs, namespaceID.String())
223223

224224
if len(queryString) > 0 {
225-
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
225+
whereClauses = append(whereClauses, queryString)
226226
}
227227

228228
if token != nil {
@@ -283,7 +283,7 @@ func (c *mysqlQueryConverter) buildCountStmt(
283283
queryArgs = append(queryArgs, namespaceID.String())
284284

285285
if len(queryString) > 0 {
286-
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
286+
whereClauses = append(whereClauses, queryString)
287287
}
288288

289289
return fmt.Sprintf(

common/persistence/visibility/store/sql/query_converter_postgresql.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func (c *pgQueryConverter) buildSelectStmt(
229229
queryArgs = append(queryArgs, namespaceID.String())
230230

231231
if len(queryString) > 0 {
232-
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
232+
whereClauses = append(whereClauses, queryString)
233233
}
234234

235235
if token != nil {
@@ -286,7 +286,7 @@ func (c *pgQueryConverter) buildCountStmt(
286286
queryArgs = append(queryArgs, namespaceID.String())
287287

288288
if len(queryString) > 0 {
289-
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
289+
whereClauses = append(whereClauses, queryString)
290290
}
291291

292292
return fmt.Sprintf(

common/persistence/visibility/store/sql/query_converter_sqlite.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (c *sqliteQueryConverter) buildSelectStmt(
240240
queryArgs = append(queryArgs, namespaceID.String())
241241

242242
if len(queryString) > 0 {
243-
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
243+
whereClauses = append(whereClauses, queryString)
244244
}
245245

246246
if token != nil {
@@ -328,7 +328,7 @@ func (c *sqliteQueryConverter) buildCountStmt(
328328
queryArgs = append(queryArgs, namespaceID.String())
329329

330330
if len(queryString) > 0 {
331-
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
331+
whereClauses = append(whereClauses, queryString)
332332
}
333333

334334
return fmt.Sprintf(

0 commit comments

Comments
 (0)