Skip to content

Commit 42326c2

Browse files
committed
fix: improve source resolution
1 parent b63d77c commit 42326c2

File tree

4 files changed

+198
-98
lines changed

4 files changed

+198
-98
lines changed

.changeset/slow-stingrays-know.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@ts-safeql/generate": patch
3+
---
4+
5+
fixed an issue when the wrong type was returned in some cases when using CTEs

packages/generate/src/ast-describe.ts

+9-8
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ function getDescribedCoalesceExpr({
439439
.at(0);
440440

441441
if (type === undefined) {
442-
return [unknownCoalesce];
442+
return [];
443443
}
444444

445445
return [
@@ -794,7 +794,7 @@ function getDescribedColumnRef({
794794
return getDescribedColumnByResolvedColumns({
795795
alias: alias,
796796
context: context,
797-
resolved: context.resolver.getAllResolvedColumns(),
797+
resolved: context.resolver.getAllResolvedColumns().map((x) => x.column),
798798
});
799799
}
800800

@@ -822,15 +822,16 @@ function getDescribedColumnRef({
822822
}
823823

824824
if (isColumnTableColumnRef(node.fields)) {
825+
const resolved = context.resolver.getColumnsByTargetField({
826+
kind: "column",
827+
table: node.fields[0].String.sval,
828+
column: node.fields[1].String.sval,
829+
});
830+
825831
return getDescribedColumnByResolvedColumns({
826832
alias: alias,
827833
context: context,
828-
resolved:
829-
context.resolver.getColumnsByTargetField({
830-
kind: "column",
831-
table: node.fields[0].String.sval,
832-
column: node.fields[1].String.sval,
833-
}) ?? [],
834+
resolved: resolved ?? [],
834835
});
835836
}
836837

packages/generate/src/ast-get-sources.ts

+122-90
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export type SourcesResolver = ReturnType<typeof getSources>;
77

88
type SourcesOptions = {
99
select: LibPgQueryAST.SelectStmt;
10+
prevSources?: Map<string, SelectSource>;
1011
nonNullableColumns: Set<string>;
1112
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
1213
relations: FlattenedRelationWithJoins[];
@@ -17,51 +18,65 @@ export type ResolvedColumn = {
1718
isNotNull: boolean;
1819
};
1920

21+
type SelectSource =
22+
| {
23+
kind: "table";
24+
schemaName: string;
25+
name: string;
26+
original: string;
27+
alias?: string;
28+
columns: ResolvedColumn[];
29+
}
30+
| { kind: "cte" | "subselect"; name: string; sources: SourcesResolver };
31+
2032
type TargetField =
2133
| { kind: "unknown"; field: string }
2234
| { kind: "column"; table: string; column: string };
2335

2436
export function getSources({
2537
pgColsBySchemaAndTableName,
2638
relations,
39+
prevSources,
2740
select,
2841
nonNullableColumns,
2942
}: SourcesOptions) {
30-
const { columns, sources: sourcesEntries } = getColumnSources(select.fromClause ?? []);
31-
const sources = new Map(sourcesEntries);
32-
33-
function getAllResolvedColumns() {
34-
return columns.map((x) => resolveColumn(x.column));
43+
const ctes = getColumnCTEs(select.withClause?.ctes ?? []);
44+
const sources: Map<string, SelectSource> = new Map([
45+
...(prevSources?.entries() ?? []),
46+
...getColumnSources(select.fromClause ?? []).entries(),
47+
]);
48+
49+
function getSourceColumns(source: SelectSource) {
50+
switch (source.kind) {
51+
case "cte":
52+
case "subselect":
53+
return source.sources.getAllResolvedColumns();
54+
case "table":
55+
return source.columns.map((column) => ({ column, source }));
56+
}
3557
}
3658

37-
function getResolvedColumnsInTable(sourceName: string) {
38-
return columns.filter((x) => x.source.name === sourceName).map((x) => resolveColumn(x.column));
59+
function getAllResolvedColumns(): { column: ResolvedColumn; source: SelectSource }[] {
60+
return [...sources.values()].map(getSourceColumns).flat();
3961
}
4062

41-
function getColumnByTableAndColumnName(p: { table: string; column: string }) {
42-
const columnSource = columns.find((x) => {
43-
if (x.column.colName !== p.column) {
44-
return false;
45-
}
63+
function getResolvedColumnsInTable(sourceName: string): ResolvedColumn[] {
64+
return fmap(sources.get(sourceName), getSourceColumns)?.map((x) => x.column) ?? [];
65+
}
4666

47-
switch (x.source.kind) {
48-
case "table":
49-
return (x.source.alias ?? x.source.name) === p.table;
50-
case "subselect":
51-
return x.source.name === p.table;
52-
}
53-
});
67+
function getColumnByTableAndColumnName(p: {
68+
table: string;
69+
column: string;
70+
}): ResolvedColumn | null {
71+
const source = sources.get(p.table);
5472

55-
if (columnSource === undefined) {
73+
if (source === undefined) {
5674
return null;
5775
}
5876

59-
const resolved =
60-
columnSource.source.kind === "table" && columnSource.source.alias !== undefined
61-
? resolveColumn({ ...columnSource.column, tableName: columnSource.source.alias })
62-
: resolveColumn(columnSource.column);
77+
const resolved = getSourceColumns(source).find((x) => x.column.column.colName === p.column);
6378

64-
return resolved;
79+
return resolved?.column ?? null;
6580
}
6681

6782
function getColumnsByTargetField(field: TargetField): ResolvedColumn[] | null {
@@ -73,12 +88,12 @@ export function getSources({
7388
const source = sources.get(field.field);
7489

7590
if (source !== undefined) {
76-
return columns.filter((x) => x.source === source).map((x) => resolveColumn(x.column));
91+
return getSourceColumns(source).map((x) => x.column);
7792
}
7893

79-
for (const { column } of columns) {
80-
if (column.colName === field.field) {
81-
return [resolveColumn(column)];
94+
for (const { column } of getAllResolvedColumns()) {
95+
if (column.column.colName === field.field) {
96+
return [column];
8297
}
8398
}
8499

@@ -87,8 +102,8 @@ export function getSources({
87102
}
88103
}
89104

90-
function checkIsNullableDueToRelation(column: PgColRow) {
91-
const findByJoin = relations.find((x) => (x.alias ?? x.joinRelName) === column.tableName);
105+
function checkIsNullableDueToRelation(tableName: string): boolean {
106+
const findByJoin = relations.find((x) => (x.alias ?? x.joinRelName) === tableName);
92107

93108
if (findByJoin !== undefined) {
94109
switch (findByJoin.joinType) {
@@ -109,7 +124,7 @@ export function getSources({
109124
}
110125
}
111126

112-
const findByRel = relations.filter((x) => x.relName === column.tableName);
127+
const findByRel = relations.filter((x) => x.relName === tableName);
113128

114129
for (const rel of findByRel) {
115130
switch (rel.joinType) {
@@ -133,27 +148,17 @@ export function getSources({
133148
return false;
134149
}
135150

136-
function resolveColumn(col: PgColRow): ResolvedColumn {
137-
const isNullableDueToRelation = checkIsNullableDueToRelation(col);
151+
function resolveColumn(col: PgColRow, tableName: string): ResolvedColumn {
152+
const isNullableDueToRelation = checkIsNullableDueToRelation(tableName);
138153
const isNotNullBasedOnAST =
139-
nonNullableColumns.has(col.colName) ||
140-
nonNullableColumns.has(`${col.tableName}.${col.colName}`);
154+
nonNullableColumns.has(col.colName) || nonNullableColumns.has(`${tableName}.${col.colName}`);
141155
const isNotNullInTable = col.colNotNull;
142156

143157
const isNonNullable = isNotNullBasedOnAST || (isNotNullInTable && !isNullableDueToRelation);
144158

145159
return { column: col, isNotNull: isNonNullable };
146160
}
147161

148-
type SelectSource =
149-
| { kind: "table"; schemaName: string; name: string; original: string; alias?: string }
150-
| { kind: "subselect"; name: string };
151-
152-
type ColumnWithSource = {
153-
column: PgColRow;
154-
source: SelectSource;
155-
};
156-
157162
function resolveRangeVarSchema(node: LibPgQueryAST.RangeVar): string {
158163
if (node.schemaname !== undefined) {
159164
return node.schemaname;
@@ -172,66 +177,93 @@ export function getSources({
172177
return "public";
173178
}
174179

175-
function getColumnSources(nodes: LibPgQueryAST.Node[]): {
176-
columns: ColumnWithSource[];
177-
sources: [string, SelectSource][];
178-
} {
179-
const columns: ColumnWithSource[] = [];
180-
const sources: [string, SelectSource][] = [];
180+
function getColumnCTEs(ctes: LibPgQueryAST.Node[]): Map<string, SourcesResolver> {
181+
const map = new Map<string, SourcesResolver>();
181182

182-
for (const node of nodes) {
183-
if (node.RangeVar !== undefined) {
184-
const source: SelectSource = {
185-
kind: "table",
186-
schemaName: resolveRangeVarSchema(node.RangeVar),
187-
original: node.RangeVar.relname,
188-
name: node.RangeVar.alias?.aliasname ?? node.RangeVar.relname,
189-
alias: node.RangeVar.alias?.aliasname,
190-
};
183+
for (const cte of ctes) {
184+
if (cte.CommonTableExpr?.ctequery?.SelectStmt === undefined) continue;
185+
if (cte.CommonTableExpr?.ctename === undefined) continue;
191186

192-
sources.push([source.name, source]);
187+
const resolver = getSources({
188+
pgColsBySchemaAndTableName,
189+
prevSources,
190+
nonNullableColumns,
191+
relations,
192+
select: cte.CommonTableExpr.ctequery.SelectStmt,
193+
});
193194

194-
for (const column of pgColsBySchemaAndTableName
195-
.get(source.schemaName)
196-
?.get(source.original) ?? []) {
197-
columns.push({ column, source });
198-
}
199-
}
195+
map.set(cte.CommonTableExpr.ctename, resolver);
196+
}
200197

201-
if (node.JoinExpr?.larg !== undefined) {
202-
const resolved = getColumnSources([node.JoinExpr.larg]);
203-
columns.push(...resolved.columns);
204-
sources.push(...resolved.sources);
205-
}
198+
return map;
199+
}
206200

207-
if (node.JoinExpr?.rarg !== undefined) {
208-
const resolved = getColumnSources([node.JoinExpr.rarg]);
209-
columns.push(...resolved.columns);
210-
sources.push(...resolved.sources);
201+
function getNodeColumnAndSources(node: LibPgQueryAST.Node): SelectSource[] {
202+
if (node.RangeVar !== undefined) {
203+
const cte = ctes.get(node.RangeVar.relname);
204+
205+
if (cte !== undefined) {
206+
return [{ kind: "cte", name: node.RangeVar.relname, sources: cte }];
211207
}
212208

213-
if (node.RangeSubselect?.subquery?.SelectStmt?.fromClause !== undefined) {
214-
const source: SelectSource = {
215-
kind: "subselect",
216-
name: node.RangeSubselect.alias?.aliasname ?? "subselect",
217-
};
209+
const schemaName = resolveRangeVarSchema(node.RangeVar);
210+
const realTableName = node.RangeVar.relname;
211+
const tableName = node.RangeVar.alias?.aliasname ?? realTableName;
212+
const tableColumns = pgColsBySchemaAndTableName.get(schemaName)?.get(realTableName) ?? [];
218213

219-
sources.push([source.name, source]);
214+
return [
215+
{
216+
kind: "table",
217+
schemaName: schemaName,
218+
original: realTableName,
219+
name: node.RangeVar.alias?.aliasname ?? node.RangeVar.relname,
220+
alias: node.RangeVar.alias?.aliasname,
221+
columns: tableColumns.map((col) => resolveColumn(col, tableName)),
222+
},
223+
];
224+
}
220225

221-
const resolvedColumns = getColumnSources(
222-
node.RangeSubselect.subquery.SelectStmt.fromClause,
223-
).columns.map((x) => x.column);
226+
const sources: SelectSource[] = [];
224227

225-
for (const column of resolvedColumns) {
226-
columns.push({ column, source });
227-
}
228-
}
228+
if (node.JoinExpr?.larg !== undefined) {
229+
sources.push(...getNodeColumnAndSources(node.JoinExpr.larg));
230+
}
231+
232+
if (node.JoinExpr?.rarg !== undefined) {
233+
sources.push(...getNodeColumnAndSources(node.JoinExpr.rarg));
229234
}
230235

231-
return { columns, sources };
236+
if (node.RangeSubselect?.subquery?.SelectStmt?.fromClause !== undefined) {
237+
sources.push({
238+
kind: "subselect",
239+
name: node.RangeSubselect.alias?.aliasname ?? "subselect",
240+
sources: getSources({
241+
nonNullableColumns,
242+
pgColsBySchemaAndTableName,
243+
relations,
244+
prevSources: new Map([
245+
...(prevSources?.entries() ?? []),
246+
...sources.map((x) => [x.name, x] as const),
247+
]),
248+
select: node.RangeSubselect.subquery.SelectStmt,
249+
}),
250+
});
251+
}
252+
253+
return sources;
254+
}
255+
256+
function getColumnSources(nodes: LibPgQueryAST.Node[]): Map<string, SelectSource> {
257+
return new Map(
258+
nodes
259+
.map(getNodeColumnAndSources)
260+
.flat()
261+
.map((x) => [x.name, x]),
262+
);
232263
}
233264

234265
return {
266+
getNodeColumnAndSources: getNodeColumnAndSources,
235267
getResolvedColumnsInTable: getResolvedColumnsInTable,
236268
getAllResolvedColumns: getAllResolvedColumns,
237269
getColumnsByTargetField: getColumnsByTargetField,

0 commit comments

Comments
 (0)