1
- // Copyright 2012, Google Inc. All rights reserved.
2
- // Use of this source code is governed by a BSD-style
3
- // license that can be found in the LICENSE file.
1
+ /*
2
+ Copyright 2017 Google Inc.
4
3
5
- //Modified by Wenbin Xiao 2015.04.18
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ */
6
16
7
17
package sqlparser
8
18
@@ -11,87 +21,128 @@ package sqlparser
11
21
import (
12
22
"errors"
13
23
"fmt"
24
+ "strconv"
25
+ "strings"
26
+ "unicode"
27
+
14
28
"github.com/xwb1989/sqlparser/dependency/sqltypes"
15
29
)
16
30
17
- // GetTableName returns the table name from the SimpleTableExpr
18
- // only if it's a simple expression. Otherwise, it returns "".
19
- func GetTableName (node SimpleTableExpr ) string {
20
- if n , ok := node .(* TableName ); ok && n .Qualifier == nil {
21
- return string (n .Name )
31
+ // These constants are used to identify the SQL statement type.
32
+ const (
33
+ StmtSelect = iota
34
+ StmtInsert
35
+ StmtReplace
36
+ StmtUpdate
37
+ StmtDelete
38
+ StmtDDL
39
+ StmtBegin
40
+ StmtCommit
41
+ StmtRollback
42
+ StmtSet
43
+ StmtShow
44
+ StmtUse
45
+ StmtOther
46
+ StmtUnknown
47
+ )
48
+
49
+ // Preview analyzes the beginning of the query using a simpler and faster
50
+ // textual comparison to identify the statement type.
51
+ func Preview (sql string ) int {
52
+ trimmed := StripLeadingComments (sql )
53
+
54
+ firstWord := trimmed
55
+ if end := strings .IndexFunc (trimmed , unicode .IsSpace ); end != - 1 {
56
+ firstWord = trimmed [:end ]
22
57
}
23
- // sub-select or '.' expression
24
- return ""
25
- }
26
58
27
- // Get the primary key ColumnDefinition of the table, sqlNode must be a CreateTable struct
28
- func GetPrimaryKey (sqlNode SQLNode ) (* ColumnDefinition , error ) {
29
- node , ok := sqlNode .(* CreateTable )
30
- if ! ok {
31
- return nil , errors .New ("fail to convert interface SQLNode to struct CreateTable" )
59
+ // Comparison is done in order of priority.
60
+ loweredFirstWord := strings .ToLower (firstWord )
61
+ switch loweredFirstWord {
62
+ case "select" :
63
+ return StmtSelect
64
+ case "insert" :
65
+ return StmtInsert
66
+ case "replace" :
67
+ return StmtReplace
68
+ case "update" :
69
+ return StmtUpdate
70
+ case "delete" :
71
+ return StmtDelete
32
72
}
33
- for _ , col := range node .ColumnDefinitions {
34
- for _ , att := range col .ColumnAtts {
35
- if att == AST_PRIMARY_KEY {
36
- return col , nil
37
- }
38
- }
73
+ switch strings .ToLower (trimmed ) {
74
+ case "begin" , "start transaction" :
75
+ return StmtBegin
76
+ case "commit" :
77
+ return StmtCommit
78
+ case "rollback" :
79
+ return StmtRollback
39
80
}
40
- return nil , errors .New ("unable to find primary key" )
81
+ switch loweredFirstWord {
82
+ case "create" , "alter" , "rename" , "drop" :
83
+ return StmtDDL
84
+ case "set" :
85
+ return StmtSet
86
+ case "show" :
87
+ return StmtShow
88
+ case "use" :
89
+ return StmtUse
90
+ case "analyze" , "describe" , "desc" , "explain" , "repair" , "optimize" , "truncate" :
91
+ return StmtOther
92
+ }
93
+ return StmtUnknown
41
94
}
42
95
43
- //Get ColumnDefinition by name, sqlNode must be a CreateTable struct
44
- func GetColumnByName (sqlNode SQLNode , name string ) (* ColumnDefinition , error ) {
45
- node , ok := sqlNode .(* CreateTable )
46
- if ! ok {
47
- return nil , errors .New ("fail to convert interface SQLNode to struct CreateTable" )
48
- }
49
- for _ , col := range node .ColumnDefinitions {
50
- if col .ColName == name {
51
- return col , nil
52
- }
96
+ // IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
97
+ func IsDML (sql string ) bool {
98
+ switch Preview (sql ) {
99
+ case StmtInsert , StmtReplace , StmtUpdate , StmtDelete :
100
+ return true
53
101
}
54
- return nil , errors . New ( "unable to find the column" )
102
+ return false
55
103
}
56
104
57
- // GetColName returns the column name, only if
58
- // it's a simple expression. Otherwise, it returns "".
59
- func GetColName (node Expr ) string {
60
- if n , ok := node .(* ColName ); ok {
61
- return string ( n .Name )
105
+ // GetTableName returns the table name from the SimpleTableExpr
106
+ // only if it's a simple expression. Otherwise, it returns "".
107
+ func GetTableName (node SimpleTableExpr ) TableIdent {
108
+ if n , ok := node .(TableName ); ok && n . Qualifier . IsEmpty () {
109
+ return n .Name
62
110
}
63
- return ""
111
+ // sub-select or '.' expression
112
+ return NewTableIdent ("" )
64
113
}
65
114
66
- // IsColName returns true if the ValExpr is a *ColName.
67
- func IsColName (node ValExpr ) bool {
115
+ // IsColName returns true if the Expr is a *ColName.
116
+ func IsColName (node Expr ) bool {
68
117
_ , ok := node .(* ColName )
69
118
return ok
70
119
}
71
120
72
- // IsValue returns true if the ValExpr is a string, number or value arg.
121
+ // IsValue returns true if the Expr is a string, integral or value arg.
73
122
// NULL is not considered to be a value.
74
- func IsValue (node ValExpr ) bool {
75
- switch node .(type ) {
76
- case StrVal , NumVal , ValArg :
77
- return true
123
+ func IsValue (node Expr ) bool {
124
+ switch v := node .(type ) {
125
+ case * SQLVal :
126
+ switch v .Type {
127
+ case StrVal , HexVal , IntVal , ValArg :
128
+ return true
129
+ }
78
130
}
79
131
return false
80
132
}
81
133
82
- // HasINCaluse returns true if any of the conditions has an IN clause.
83
- func HasINClause (conditions []BoolExpr ) bool {
84
- for _ , node := range conditions {
85
- if c , ok := node .(* ComparisonExpr ); ok && c .Operator == AST_IN {
86
- return true
87
- }
134
+ // IsNull returns true if the Expr is SQL NULL
135
+ func IsNull (node Expr ) bool {
136
+ switch node .(type ) {
137
+ case * NullVal :
138
+ return true
88
139
}
89
140
return false
90
141
}
91
142
92
- // IsSimpleTuple returns true if the ValExpr is a ValTuple that
143
+ // IsSimpleTuple returns true if the Expr is a ValTuple that
93
144
// contains simple values or if it's a list arg.
94
- func IsSimpleTuple (node ValExpr ) bool {
145
+ func IsSimpleTuple (node Expr ) bool {
95
146
switch vals := node .(type ) {
96
147
case ValTuple :
97
148
for _ , n := range vals {
@@ -107,38 +158,49 @@ func IsSimpleTuple(node ValExpr) bool {
107
158
return false
108
159
}
109
160
110
- // AsInterface converts the ValExpr to an interface. It converts
111
- // ValTuple to []interface{}, ValArg to string, StrVal to sqltypes.String,
112
- // NumVal to sqltypes.Numeric, NullVal to nil.
113
- // Otherwise, it returns an error.
114
- func AsInterface (node ValExpr ) (interface {}, error ) {
161
+ // NewPlanValue builds a sqltypes.PlanValue from an Expr.
162
+ func NewPlanValue (node Expr ) (sqltypes.PlanValue , error ) {
115
163
switch node := node .(type ) {
116
- case ValTuple :
117
- vals := make ([]interface {}, 0 , len (node ))
118
- for _ , val := range node {
119
- v , err := AsInterface (val )
164
+ case * SQLVal :
165
+ switch node .Type {
166
+ case ValArg :
167
+ return sqltypes.PlanValue {Key : string (node .Val [1 :])}, nil
168
+ case IntVal :
169
+ n , err := sqltypes .NewIntegral (string (node .Val ))
120
170
if err != nil {
121
- return nil , err
171
+ return sqltypes. PlanValue {}, fmt . Errorf ( "%v" , err )
122
172
}
123
- vals = append (vals , v )
173
+ return sqltypes.PlanValue {Value : n }, nil
174
+ case StrVal :
175
+ return sqltypes.PlanValue {Value : sqltypes .MakeTrusted (sqltypes .VarBinary , node .Val )}, nil
176
+ case HexVal :
177
+ v , err := node .HexDecode ()
178
+ if err != nil {
179
+ return sqltypes.PlanValue {}, fmt .Errorf ("%v" , err )
180
+ }
181
+ return sqltypes.PlanValue {Value : sqltypes .MakeTrusted (sqltypes .VarBinary , v )}, nil
124
182
}
125
- return vals , nil
126
- case ValArg :
127
- return string (node ), nil
128
183
case ListArg :
129
- return string (node ), nil
130
- case StrVal :
131
- return sqltypes .MakeString (node ), nil
132
- case NumVal :
133
- n , err := sqltypes .BuildNumeric (string (node ))
134
- if err != nil {
135
- return nil , fmt .Errorf ("type mismatch: %s" , err )
184
+ return sqltypes.PlanValue {ListKey : string (node [2 :])}, nil
185
+ case ValTuple :
186
+ pv := sqltypes.PlanValue {
187
+ Values : make ([]sqltypes.PlanValue , 0 , len (node )),
136
188
}
137
- return n , nil
189
+ for _ , val := range node {
190
+ innerpv , err := NewPlanValue (val )
191
+ if err != nil {
192
+ return sqltypes.PlanValue {}, err
193
+ }
194
+ if innerpv .ListKey != "" || innerpv .Values != nil {
195
+ return sqltypes.PlanValue {}, errors .New ("unsupported: nested lists" )
196
+ }
197
+ pv .Values = append (pv .Values , innerpv )
198
+ }
199
+ return pv , nil
138
200
case * NullVal :
139
- return nil , nil
201
+ return sqltypes. PlanValue {} , nil
140
202
}
141
- return nil , fmt .Errorf ("unexpected node %v " , node )
203
+ return sqltypes. PlanValue {} , fmt .Errorf ("expression is too complex '%v' " , String ( node ) )
142
204
}
143
205
144
206
// StringIn is a convenience function that returns
@@ -151,3 +213,48 @@ func StringIn(str string, values ...string) bool {
151
213
}
152
214
return false
153
215
}
216
+
217
+ // ExtractSetValues returns a map of key-value pairs
218
+ // if the query is a SET statement. Values can be int64 or string.
219
+ // Since set variable names are case insensitive, all keys are returned
220
+ // as lower case.
221
+ func ExtractSetValues (sql string ) (keyValues map [string ]interface {}, charset string , err error ) {
222
+ stmt , err := Parse (sql )
223
+ if err != nil {
224
+ return nil , "" , err
225
+ }
226
+ setStmt , ok := stmt .(* Set )
227
+ if ! ok {
228
+ return nil , "" , fmt .Errorf ("ast did not yield *sqlparser.Set: %T" , stmt )
229
+ }
230
+ result := make (map [string ]interface {})
231
+ for _ , expr := range setStmt .Exprs {
232
+ if ! expr .Name .Qualifier .IsEmpty () {
233
+ return nil , "" , fmt .Errorf ("invalid syntax: %v" , String (expr .Name ))
234
+ }
235
+ key := expr .Name .Name .Lowered ()
236
+
237
+ switch expr := expr .Expr .(type ) {
238
+ case * SQLVal :
239
+ switch expr .Type {
240
+ case StrVal :
241
+ result [key ] = string (expr .Val )
242
+ case IntVal :
243
+ num , err := strconv .ParseInt (string (expr .Val ), 0 , 64 )
244
+ if err != nil {
245
+ return nil , "" , err
246
+ }
247
+ result [key ] = num
248
+ default :
249
+ return nil , "" , fmt .Errorf ("invalid value type: %v" , String (expr ))
250
+ }
251
+ case * NullVal :
252
+ result [key ] = nil
253
+ case * Default :
254
+ result [key ] = "default"
255
+ default :
256
+ return nil , "" , fmt .Errorf ("invalid syntax: %s" , String (expr ))
257
+ }
258
+ }
259
+ return result , setStmt .Charset .Lowered (), nil
260
+ }
0 commit comments