Skip to content

Commit 3abd50d

Browse files
authored
Better SQL query splitter (#3791)
1 parent 9389673 commit 3abd50d

File tree

6 files changed

+255
-90
lines changed

6 files changed

+255
-90
lines changed

common/persistence/query_util.go

+112-8
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,29 @@
2525
package persistence
2626

2727
import (
28+
"bytes"
2829
"fmt"
2930
"io"
3031
"os"
3132
"strings"
33+
"unicode"
3234
)
3335

3436
// Tokens and sizing used by the statement splitter in this file.
const (
35-
queryDelimiter = ";"
37+
// queryDelimiter terminates a statement when the splitter sees it outside
// parentheses and BEGIN/END blocks.
queryDelimiter = ';'
3638
// querySliceDefaultSize is the initial capacity of the returned statement slice.
querySliceDefaultSize = 100
39+
40+
// sqlLeftParenthesis and sqlRightParenthesis are tracked on the splitter's
// stack so that delimiters inside parentheses do not end a statement.
sqlLeftParenthesis = '('
41+
sqlRightParenthesis = ')'
42+
// sqlBeginKeyword and sqlEndKeyword are matched against lower-cased input
// as whole words and bracket a block in which delimiters are ignored.
sqlBeginKeyword = "begin"
43+
sqlEndKeyword = "end"
44+
// sqlLineComment starts a comment that runs up to the next newline.
sqlLineComment = "--"
45+
// sqlSingleQuote and sqlDoubleQuote open a quoted span that is skipped up to
// the next occurrence of the same quote character.
sqlSingleQuote = '\''
46+
sqlDoubleQuote = '"'
3747
)
3848

39-
// LoadAndSplitQuery loads and split cql / sql query into one statement per string
49+
// LoadAndSplitQuery loads and splits cql / sql queries into one statement per string.
50+
// Comments are removed from the query.
4051
func LoadAndSplitQuery(
4152
filePaths []string,
4253
) ([]string, error) {
@@ -53,26 +64,119 @@ func LoadAndSplitQuery(
5364
return LoadAndSplitQueryFromReaders(files)
5465
}
5566

56-
// LoadAndSplitQueryFromReaders loads and split cql / sql query into one statement per string
67+
// LoadAndSplitQueryFromReaders loads and split cql / sql query into one statement per string.
68+
// Comments are removed from the query.
5769
func LoadAndSplitQueryFromReaders(
5870
readers []io.Reader,
5971
) ([]string, error) {
60-
6172
result := make([]string, 0, querySliceDefaultSize)
62-
6373
for _, r := range readers {
6474
content, err := io.ReadAll(r)
6575
if err != nil {
6676
return nil, fmt.Errorf("error reading contents: %w", err)
6777
}
68-
for _, stmt := range strings.Split(string(content), queryDelimiter) {
69-
stmt = strings.TrimSpace(stmt)
78+
n := len(content)
79+
contentStr := string(bytes.ToLower(content))
80+
for i, j := 0, 0; i < n; i = j {
81+
// stack to keep track of open parenthesis/blocks
82+
var st []byte
83+
var stmtBuilder strings.Builder
84+
85+
stmtLoop:
86+
for ; j < n; j++ {
87+
switch contentStr[j] {
88+
case queryDelimiter:
89+
if len(st) == 0 {
90+
j++
91+
break stmtLoop
92+
}
93+
94+
case sqlLeftParenthesis:
95+
st = append(st, sqlLeftParenthesis)
96+
97+
case sqlRightParenthesis:
98+
if len(st) == 0 || st[len(st)-1] != sqlLeftParenthesis {
99+
return nil, fmt.Errorf("error reading contents: unmatched right parenthesis")
100+
}
101+
st = st[:len(st)-1]
102+
103+
case sqlBeginKeyword[0]:
104+
if hasWordAt(contentStr, sqlBeginKeyword, j) {
105+
st = append(st, sqlBeginKeyword[0])
106+
j += len(sqlBeginKeyword) - 1
107+
}
108+
109+
case sqlEndKeyword[0]:
110+
if hasWordAt(contentStr, sqlEndKeyword, j) {
111+
if len(st) == 0 || st[len(st)-1] != sqlBeginKeyword[0] {
112+
return nil, fmt.Errorf("error reading contents: unmatched `END` keyword")
113+
}
114+
st = st[:len(st)-1]
115+
j += len(sqlEndKeyword) - 1
116+
}
117+
118+
case sqlSingleQuote, sqlDoubleQuote:
119+
quote := contentStr[j]
120+
j++
121+
for j < n && contentStr[j] != quote {
122+
j++
123+
}
124+
if j == n {
125+
return nil, fmt.Errorf("error reading contents: unmatched quotes")
126+
}
127+
128+
case sqlLineComment[0]:
129+
if j+len(sqlLineComment) <= n && contentStr[j:j+len(sqlLineComment)] == sqlLineComment {
130+
_, _ = stmtBuilder.Write(bytes.TrimRight(content[i:j], " "))
131+
for j < n && contentStr[j] != '\n' {
132+
j++
133+
}
134+
i = j
135+
}
136+
137+
default:
138+
// no-op: generic character
139+
}
140+
}
141+
142+
if len(st) > 0 {
143+
switch st[len(st)-1] {
144+
case sqlLeftParenthesis:
145+
return nil, fmt.Errorf("error reading contents: unmatched left parenthesis")
146+
case sqlBeginKeyword[0]:
147+
return nil, fmt.Errorf("error reading contents: unmatched `BEGIN` keyword")
148+
default:
149+
// should never enter here
150+
return nil, fmt.Errorf("error reading contents: unmatched `%c`", st[len(st)-1])
151+
}
152+
}
153+
154+
_, _ = stmtBuilder.Write(content[i:j])
155+
stmt := strings.TrimSpace(stmtBuilder.String())
70156
if stmt == "" {
71157
continue
72158
}
73159
result = append(result, stmt)
74160
}
75-
76161
}
77162
return result, nil
78163
}
164+
165+
// hasWordAt reports whether word occurs in s at byte offset pos as a whole
// word: the match must be exact (case-sensitive) and the bytes immediately
// before and after it, when they exist, must not be identifier characters
// (letters, digits or '_'). Treating '_' as part of a word keeps identifiers
// such as "end_time" or "tx_begin" from being taken for keywords.
// An out-of-range pos yields false.
func hasWordAt(s, word string, pos int) bool {
	end := pos + len(word)
	if pos < 0 || end > len(s) || s[pos:end] != word {
		return false
	}
	if pos > 0 && isWordByte(s[pos-1]) {
		return false
	}
	if end < len(s) && isWordByte(s[end]) {
		return false
	}
	return true
}

// isWordByte reports whether c can be part of an identifier: an underscore or
// anything isAlphanumeric accepts.
func isWordByte(c byte) bool {
	return c == '_' || isAlphanumeric(c)
}

// isAlphanumeric reports whether the single byte c, converted to a rune, is a
// Unicode letter or digit. Because it looks at one byte at a time, bytes of a
// multi-byte UTF-8 sequence are classified individually by their byte value.
func isAlphanumeric(c byte) bool {
	return unicode.IsLetter(rune(c)) || unicode.IsDigit(rune(c))
}

common/persistence/query_util_test.go

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// The MIT License
2+
//
3+
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
4+
//
5+
// Copyright (c) 2020 Uber Technologies, Inc.
6+
//
7+
// Permission is hereby granted, free of charge, to any person obtaining a copy
8+
// of this software and associated documentation files (the "Software"), to deal
9+
// in the Software without restriction, including without limitation the rights
10+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
// copies of the Software, and to permit persons to whom the Software is
12+
// furnished to do so, subject to the following conditions:
13+
//
14+
// The above copyright notice and this permission notice shall be included in
15+
// all copies or substantial portions of the Software.
16+
//
17+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
// THE SOFTWARE.
24+
25+
package persistence
26+
27+
import (
28+
"bytes"
29+
"io"
30+
"testing"
31+
32+
"github.com/stretchr/testify/require"
33+
"github.com/stretchr/testify/suite"
34+
35+
"go.temporal.io/server/common/log"
36+
)
37+
38+
type (
39+
queryUtilSuite struct {
40+
suite.Suite
41+
// override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test,
42+
// not merely log an error
43+
*require.Assertions
44+
logger log.Logger
45+
}
46+
)
47+
48+
func TestQueryUtilSuite(t *testing.T) {
49+
s := new(queryUtilSuite)
50+
suite.Run(t, s)
51+
}
52+
53+
func (s *queryUtilSuite) SetupTest() {
54+
s.logger = log.NewTestLogger()
55+
// Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil
56+
s.Assertions = require.New(s.T())
57+
}
58+
59+
func (s *queryUtilSuite) TestLoadAndSplitQueryFromReaders() {
60+
input := `
61+
CREATE TABLE test (
62+
id BIGINT not null,
63+
col1 BIGINT, -- comment with unmatched parenthesis )
64+
col2 VARCHAR(255),
65+
PRIMARY KEY (id)
66+
);
67+
68+
CREATE INDEX test_idx ON test (col1);
69+
70+
--begin
71+
CREATE TRIGGER test_ai AFTER INSERT ON test
72+
BEGIN
73+
SELECT *, 'string with unmatched chars ")' FROM test;
74+
--end
75+
END;
76+
77+
-- trailing comment
78+
`
79+
statements, err := LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
80+
s.NoError(err)
81+
s.Equal(3, len(statements))
82+
s.Equal(
83+
`CREATE TABLE test (
84+
id BIGINT not null,
85+
col1 BIGINT,
86+
col2 VARCHAR(255),
87+
PRIMARY KEY (id)
88+
);`,
89+
statements[0],
90+
)
91+
s.Equal(`CREATE INDEX test_idx ON test (col1);`, statements[1])
92+
// comments are removed, but the inner content is not trimmed
93+
s.Equal(
94+
`CREATE TRIGGER test_ai AFTER INSERT ON test
95+
BEGIN
96+
SELECT *, 'string with unmatched chars ")' FROM test;
97+
98+
END;`,
99+
statements[2],
100+
)
101+
102+
input = "CREATE TABLE test (;"
103+
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
104+
s.Error(err, "error reading contents: unmatched left parenthesis")
105+
s.Nil(statements)
106+
107+
input = "CREATE TABLE test ());"
108+
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
109+
s.Error(err, "error reading contents: unmatched right parenthesis")
110+
s.Nil(statements)
111+
112+
input = "begin"
113+
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
114+
s.Error(err, "error reading contents: unmatched `BEGIN` keyword")
115+
s.Nil(statements)
116+
117+
input = "end"
118+
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
119+
s.Error(err, "error reading contents: unmatched `END` keyword")
120+
s.Nil(statements)
121+
122+
input = "select ' from test;"
123+
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
124+
s.Error(err, "error reading contents: unmatched quotes")
125+
s.Nil(statements)
126+
}
127+
128+
func (s *queryUtilSuite) TestHasWordAt() {
129+
s.True(hasWordAt("BEGIN", "BEGIN", 0))
130+
s.True(hasWordAt(" BEGIN ", "BEGIN", 1))
131+
s.True(hasWordAt(")BEGIN;", "BEGIN", 1))
132+
s.False(hasWordAt("BEGIN", "BEGIN", 1))
133+
s.False(hasWordAt("sBEGIN", "BEGIN", 1))
134+
s.False(hasWordAt("BEGINs", "BEGIN", 0))
135+
s.False(hasWordAt("7BEGIN", "BEGIN", 1))
136+
s.False(hasWordAt("BEGIN7", "BEGIN", 0))
137+
}

tools/common/schema/setuptask.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232

3333
"go.temporal.io/server/common/log"
3434
"go.temporal.io/server/common/log/tag"
35+
"go.temporal.io/server/common/persistence"
3536
)
3637

3738
// SetupTask represents a task
@@ -75,7 +76,7 @@ func (task *SetupTask) Run() error {
7576
if err != nil {
7677
return err
7778
}
78-
stmts, err := ParseFile(filePath)
79+
stmts, err := persistence.LoadAndSplitQuery([]string{filePath})
7980
if err != nil {
8081
return err
8182
}

tools/common/schema/test/dbtest.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434

3535
"go.temporal.io/server/common/log"
3636
"go.temporal.io/server/common/log/tag"
37+
"go.temporal.io/server/common/persistence"
3738
"go.temporal.io/server/tests/testutils"
3839
"go.temporal.io/server/tools/common/schema"
3940
)
@@ -83,7 +84,7 @@ func (tb *DBTestBase) RunParseFileTest(content string) {
8384

8485
_, err := cqlFile.WriteString(content)
8586
tb.NoError(err)
86-
stmts, err := schema.ParseFile(cqlFile.Name())
87+
stmts, err := persistence.LoadAndSplitQuery([]string{cqlFile.Name()})
8788
tb.Nil(err)
8889
tb.Equal(2, len(stmts), "wrong number of sql statements")
8990
}

tools/common/schema/updatetask.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import (
4141

4242
"go.temporal.io/server/common/log"
4343
"go.temporal.io/server/common/log/tag"
44+
"go.temporal.io/server/common/persistence"
4445
)
4546

4647
type (
@@ -230,7 +231,7 @@ func (task *UpdateTask) parseSQLStmts(dir string, manifest *manifest) ([]string,
230231
for _, file := range manifest.SchemaUpdateCqlFiles {
231232
path := dir + "/" + file
232233
task.logger.Info("Processing schema file: " + path)
233-
stmts, err := ParseFile(path)
234+
stmts, err := persistence.LoadAndSplitQuery([]string{path})
234235
if err != nil {
235236
return nil, fmt.Errorf("error parsing file %v, err=%v", path, err)
236237
}

0 commit comments

Comments
 (0)