@@ -7,6 +7,7 @@ export type SourcesResolver = ReturnType<typeof getSources>;
7
7
8
8
type SourcesOptions = {
9
9
select : LibPgQueryAST . SelectStmt ;
10
+ prevSources ?: Map < string , SelectSource > ;
10
11
nonNullableColumns : Set < string > ;
11
12
pgColsBySchemaAndTableName : Map < string , Map < string , PgColRow [ ] > > ;
12
13
relations : FlattenedRelationWithJoins [ ] ;
@@ -17,51 +18,65 @@ export type ResolvedColumn = {
17
18
isNotNull : boolean ;
18
19
} ;
19
20
21
+ type SelectSource =
22
+ | {
23
+ kind : "table" ;
24
+ schemaName : string ;
25
+ name : string ;
26
+ original : string ;
27
+ alias ?: string ;
28
+ columns : ResolvedColumn [ ] ;
29
+ }
30
+ | { kind : "cte" | "subselect" ; name : string ; sources : SourcesResolver } ;
31
+
20
32
type TargetField =
21
33
| { kind : "unknown" ; field : string }
22
34
| { kind : "column" ; table : string ; column : string } ;
23
35
24
36
export function getSources ( {
25
37
pgColsBySchemaAndTableName,
26
38
relations,
39
+ prevSources,
27
40
select,
28
41
nonNullableColumns,
29
42
} : SourcesOptions ) {
30
- const { columns, sources : sourcesEntries } = getColumnSources ( select . fromClause ?? [ ] ) ;
31
- const sources = new Map ( sourcesEntries ) ;
32
-
33
- function getAllResolvedColumns ( ) {
34
- return columns . map ( ( x ) => resolveColumn ( x . column ) ) ;
43
+ const ctes = getColumnCTEs ( select . withClause ?. ctes ?? [ ] ) ;
44
+ const sources : Map < string , SelectSource > = new Map ( [
45
+ ...( prevSources ?. entries ( ) ?? [ ] ) ,
46
+ ...getColumnSources ( select . fromClause ?? [ ] ) . entries ( ) ,
47
+ ] ) ;
48
+
49
+ function getSourceColumns ( source : SelectSource ) {
50
+ switch ( source . kind ) {
51
+ case "cte" :
52
+ case "subselect" :
53
+ return source . sources . getAllResolvedColumns ( ) ;
54
+ case "table" :
55
+ return source . columns . map ( ( column ) => ( { column, source } ) ) ;
56
+ }
35
57
}
36
58
37
- function getResolvedColumnsInTable ( sourceName : string ) {
38
- return columns . filter ( ( x ) => x . source . name === sourceName ) . map ( ( x ) => resolveColumn ( x . column ) ) ;
59
+ function getAllResolvedColumns ( ) : { column : ResolvedColumn ; source : SelectSource } [ ] {
60
+ return [ ... sources . values ( ) ] . map ( getSourceColumns ) . flat ( ) ;
39
61
}
40
62
41
- function getColumnByTableAndColumnName ( p : { table : string ; column : string } ) {
42
- const columnSource = columns . find ( ( x ) => {
43
- if ( x . column . colName !== p . column ) {
44
- return false ;
45
- }
63
+ function getResolvedColumnsInTable ( sourceName : string ) : ResolvedColumn [ ] {
64
+ return fmap ( sources . get ( sourceName ) , getSourceColumns ) ?. map ( ( x ) => x . column ) ?? [ ] ;
65
+ }
46
66
47
- switch ( x . source . kind ) {
48
- case "table" :
49
- return ( x . source . alias ?? x . source . name ) === p . table ;
50
- case "subselect" :
51
- return x . source . name === p . table ;
52
- }
53
- } ) ;
67
+ function getColumnByTableAndColumnName ( p : {
68
+ table : string ;
69
+ column : string ;
70
+ } ) : ResolvedColumn | null {
71
+ const source = sources . get ( p . table ) ;
54
72
55
- if ( columnSource === undefined ) {
73
+ if ( source === undefined ) {
56
74
return null ;
57
75
}
58
76
59
- const resolved =
60
- columnSource . source . kind === "table" && columnSource . source . alias !== undefined
61
- ? resolveColumn ( { ...columnSource . column , tableName : columnSource . source . alias } )
62
- : resolveColumn ( columnSource . column ) ;
77
+ const resolved = getSourceColumns ( source ) . find ( ( x ) => x . column . column . colName === p . column ) ;
63
78
64
- return resolved ;
79
+ return resolved ?. column ?? null ;
65
80
}
66
81
67
82
function getColumnsByTargetField ( field : TargetField ) : ResolvedColumn [ ] | null {
@@ -73,12 +88,12 @@ export function getSources({
73
88
const source = sources . get ( field . field ) ;
74
89
75
90
if ( source !== undefined ) {
76
- return columns . filter ( ( x ) => x . source === source ) . map ( ( x ) => resolveColumn ( x . column ) ) ;
91
+ return getSourceColumns ( source ) . map ( ( x ) => x . column ) ;
77
92
}
78
93
79
- for ( const { column } of columns ) {
80
- if ( column . colName === field . field ) {
81
- return [ resolveColumn ( column ) ] ;
94
+ for ( const { column } of getAllResolvedColumns ( ) ) {
95
+ if ( column . column . colName === field . field ) {
96
+ return [ column ] ;
82
97
}
83
98
}
84
99
@@ -87,8 +102,8 @@ export function getSources({
87
102
}
88
103
}
89
104
90
- function checkIsNullableDueToRelation ( column : PgColRow ) {
91
- const findByJoin = relations . find ( ( x ) => ( x . alias ?? x . joinRelName ) === column . tableName ) ;
105
+ function checkIsNullableDueToRelation ( tableName : string ) : boolean {
106
+ const findByJoin = relations . find ( ( x ) => ( x . alias ?? x . joinRelName ) === tableName ) ;
92
107
93
108
if ( findByJoin !== undefined ) {
94
109
switch ( findByJoin . joinType ) {
@@ -109,7 +124,7 @@ export function getSources({
109
124
}
110
125
}
111
126
112
- const findByRel = relations . filter ( ( x ) => x . relName === column . tableName ) ;
127
+ const findByRel = relations . filter ( ( x ) => x . relName === tableName ) ;
113
128
114
129
for ( const rel of findByRel ) {
115
130
switch ( rel . joinType ) {
@@ -133,27 +148,17 @@ export function getSources({
133
148
return false ;
134
149
}
135
150
136
- function resolveColumn ( col : PgColRow ) : ResolvedColumn {
137
- const isNullableDueToRelation = checkIsNullableDueToRelation ( col ) ;
151
+ function resolveColumn ( col : PgColRow , tableName : string ) : ResolvedColumn {
152
+ const isNullableDueToRelation = checkIsNullableDueToRelation ( tableName ) ;
138
153
const isNotNullBasedOnAST =
139
- nonNullableColumns . has ( col . colName ) ||
140
- nonNullableColumns . has ( `${ col . tableName } .${ col . colName } ` ) ;
154
+ nonNullableColumns . has ( col . colName ) || nonNullableColumns . has ( `${ tableName } .${ col . colName } ` ) ;
141
155
const isNotNullInTable = col . colNotNull ;
142
156
143
157
const isNonNullable = isNotNullBasedOnAST || ( isNotNullInTable && ! isNullableDueToRelation ) ;
144
158
145
159
return { column : col , isNotNull : isNonNullable } ;
146
160
}
147
161
148
- type SelectSource =
149
- | { kind : "table" ; schemaName : string ; name : string ; original : string ; alias ?: string }
150
- | { kind : "subselect" ; name : string } ;
151
-
152
- type ColumnWithSource = {
153
- column : PgColRow ;
154
- source : SelectSource ;
155
- } ;
156
-
157
162
function resolveRangeVarSchema ( node : LibPgQueryAST . RangeVar ) : string {
158
163
if ( node . schemaname !== undefined ) {
159
164
return node . schemaname ;
@@ -172,66 +177,93 @@ export function getSources({
172
177
return "public" ;
173
178
}
174
179
175
- function getColumnSources ( nodes : LibPgQueryAST . Node [ ] ) : {
176
- columns : ColumnWithSource [ ] ;
177
- sources : [ string , SelectSource ] [ ] ;
178
- } {
179
- const columns : ColumnWithSource [ ] = [ ] ;
180
- const sources : [ string , SelectSource ] [ ] = [ ] ;
180
+ function getColumnCTEs ( ctes : LibPgQueryAST . Node [ ] ) : Map < string , SourcesResolver > {
181
+ const map = new Map < string , SourcesResolver > ( ) ;
181
182
182
- for ( const node of nodes ) {
183
- if ( node . RangeVar !== undefined ) {
184
- const source : SelectSource = {
185
- kind : "table" ,
186
- schemaName : resolveRangeVarSchema ( node . RangeVar ) ,
187
- original : node . RangeVar . relname ,
188
- name : node . RangeVar . alias ?. aliasname ?? node . RangeVar . relname ,
189
- alias : node . RangeVar . alias ?. aliasname ,
190
- } ;
183
+ for ( const cte of ctes ) {
184
+ if ( cte . CommonTableExpr ?. ctequery ?. SelectStmt === undefined ) continue ;
185
+ if ( cte . CommonTableExpr ?. ctename === undefined ) continue ;
191
186
192
- sources . push ( [ source . name , source ] ) ;
187
+ const resolver = getSources ( {
188
+ pgColsBySchemaAndTableName,
189
+ prevSources,
190
+ nonNullableColumns,
191
+ relations,
192
+ select : cte . CommonTableExpr . ctequery . SelectStmt ,
193
+ } ) ;
193
194
194
- for ( const column of pgColsBySchemaAndTableName
195
- . get ( source . schemaName )
196
- ?. get ( source . original ) ?? [ ] ) {
197
- columns . push ( { column, source } ) ;
198
- }
199
- }
195
+ map . set ( cte . CommonTableExpr . ctename , resolver ) ;
196
+ }
200
197
201
- if ( node . JoinExpr ?. larg !== undefined ) {
202
- const resolved = getColumnSources ( [ node . JoinExpr . larg ] ) ;
203
- columns . push ( ...resolved . columns ) ;
204
- sources . push ( ...resolved . sources ) ;
205
- }
198
+ return map ;
199
+ }
206
200
207
- if ( node . JoinExpr ?. rarg !== undefined ) {
208
- const resolved = getColumnSources ( [ node . JoinExpr . rarg ] ) ;
209
- columns . push ( ...resolved . columns ) ;
210
- sources . push ( ...resolved . sources ) ;
201
+ function getNodeColumnAndSources ( node : LibPgQueryAST . Node ) : SelectSource [ ] {
202
+ if ( node . RangeVar !== undefined ) {
203
+ const cte = ctes . get ( node . RangeVar . relname ) ;
204
+
205
+ if ( cte !== undefined ) {
206
+ return [ { kind : "cte" , name : node . RangeVar . relname , sources : cte } ] ;
211
207
}
212
208
213
- if ( node . RangeSubselect ?. subquery ?. SelectStmt ?. fromClause !== undefined ) {
214
- const source : SelectSource = {
215
- kind : "subselect" ,
216
- name : node . RangeSubselect . alias ?. aliasname ?? "subselect" ,
217
- } ;
209
+ const schemaName = resolveRangeVarSchema ( node . RangeVar ) ;
210
+ const realTableName = node . RangeVar . relname ;
211
+ const tableName = node . RangeVar . alias ?. aliasname ?? realTableName ;
212
+ const tableColumns = pgColsBySchemaAndTableName . get ( schemaName ) ?. get ( realTableName ) ?? [ ] ;
218
213
219
- sources . push ( [ source . name , source ] ) ;
214
+ return [
215
+ {
216
+ kind : "table" ,
217
+ schemaName : schemaName ,
218
+ original : realTableName ,
219
+ name : node . RangeVar . alias ?. aliasname ?? node . RangeVar . relname ,
220
+ alias : node . RangeVar . alias ?. aliasname ,
221
+ columns : tableColumns . map ( ( col ) => resolveColumn ( col , tableName ) ) ,
222
+ } ,
223
+ ] ;
224
+ }
220
225
221
- const resolvedColumns = getColumnSources (
222
- node . RangeSubselect . subquery . SelectStmt . fromClause ,
223
- ) . columns . map ( ( x ) => x . column ) ;
226
+ const sources : SelectSource [ ] = [ ] ;
224
227
225
- for ( const column of resolvedColumns ) {
226
- columns . push ( { column, source } ) ;
227
- }
228
- }
228
+ if ( node . JoinExpr ?. larg !== undefined ) {
229
+ sources . push ( ...getNodeColumnAndSources ( node . JoinExpr . larg ) ) ;
230
+ }
231
+
232
+ if ( node . JoinExpr ?. rarg !== undefined ) {
233
+ sources . push ( ...getNodeColumnAndSources ( node . JoinExpr . rarg ) ) ;
229
234
}
230
235
231
- return { columns, sources } ;
236
+ if ( node . RangeSubselect ?. subquery ?. SelectStmt ?. fromClause !== undefined ) {
237
+ sources . push ( {
238
+ kind : "subselect" ,
239
+ name : node . RangeSubselect . alias ?. aliasname ?? "subselect" ,
240
+ sources : getSources ( {
241
+ nonNullableColumns,
242
+ pgColsBySchemaAndTableName,
243
+ relations,
244
+ prevSources : new Map ( [
245
+ ...( prevSources ?. entries ( ) ?? [ ] ) ,
246
+ ...sources . map ( ( x ) => [ x . name , x ] as const ) ,
247
+ ] ) ,
248
+ select : node . RangeSubselect . subquery . SelectStmt ,
249
+ } ) ,
250
+ } ) ;
251
+ }
252
+
253
+ return sources ;
254
+ }
255
+
256
+ function getColumnSources ( nodes : LibPgQueryAST . Node [ ] ) : Map < string , SelectSource > {
257
+ return new Map (
258
+ nodes
259
+ . map ( getNodeColumnAndSources )
260
+ . flat ( )
261
+ . map ( ( x ) => [ x . name , x ] ) ,
262
+ ) ;
232
263
}
233
264
234
265
return {
266
+ getNodeColumnAndSources : getNodeColumnAndSources ,
235
267
getResolvedColumnsInTable : getResolvedColumnsInTable ,
236
268
getAllResolvedColumns : getAllResolvedColumns ,
237
269
getColumnsByTargetField : getColumnsByTargetField ,
0 commit comments