Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct type inference for arithmetic operations #359

Merged
merged 8 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/chatty-hairs-hide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ts-safeql/generate": patch
---

fixed an issue where the inferred typed was incorrect when dealing with arithmetic operations
2 changes: 1 addition & 1 deletion packages/eslint-plugin/src/rules/check-sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,7 @@ RuleTester.describe("check-sql", () => {
await sql<Caregiver[]>\`
SELECT
CASE WHEN caregiver.id IS NOT NULL
THEN jsonb_build_object('is_test', caregiver.middle_name NOT LIKE '%test%')
THEN jsonb_build_object('is_test', caregiver.first_name LIKE '%test%')
ELSE NULL
END AS meta
FROM
Expand Down
2 changes: 1 addition & 1 deletion packages/generate/src/ast-decribe.utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ export function isSingleCell<T>(arr: T[]): arr is [T] {
return arr.length === 1;
}

function isTuple<T>(arr: T[]): arr is [T, T] {
export function isTuple<T>(arr: T[]): arr is [T, T] {
return arr.length === 2;
}
133 changes: 112 additions & 21 deletions packages/generate/src/ast-describe.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { fmap, normalizeIndent } from "@ts-safeql/shared";
import { defaultTypeExprMapping, fmap, normalizeIndent } from "@ts-safeql/shared";
import * as LibPgQueryAST from "@ts-safeql/sql-ast";
import {
isColumnStarRef,
isColumnTableColumnRef,
isColumnTableStarRef,
isColumnUnknownRef,
isSingleCell,
isTuple,
} from "./ast-decribe.utils";
import { ResolvedColumn, SourcesResolver, getSources } from "./ast-get-sources";
import { PgColRow, PgEnumsMaps, PgTypesMap } from "./generate";
Expand All @@ -20,7 +21,7 @@ type ASTDescriptionOptions = {
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
pgTypes: PgTypesMap;
pgEnums: PgEnumsMaps;
pgFns: Map<string, string>;
pgFns: Map<string, { ts: string; pg: string }>;
};

type ASTDescriptionContext = ASTDescriptionOptions & {
Expand All @@ -38,7 +39,7 @@ export type ASTDescribedColumnType =
| { kind: "union"; value: ASTDescribedColumnType[] }
| { kind: "array"; value: ASTDescribedColumnType }
| { kind: "object"; value: [string, ASTDescribedColumnType][] }
| { kind: "type"; value: string }
| { kind: "type"; value: string; type: string }
| { kind: "literal"; value: string; base: ASTDescribedColumnType };

export function getASTDescription(params: ASTDescriptionOptions): Map<number, ASTDescribedColumn> {
Expand Down Expand Up @@ -82,20 +83,32 @@ export function getASTDescription(params: ASTDescriptionOptions): Map<number, AS
p: { oid: number; baseOid: number | null } | { name: string },
): ASTDescribedColumnType => {
if ("name" in p) {
return { kind: "type", value: params.typesMap.get(p.name)?.value ?? "unknown" };
return {
kind: "type",
value: params.typesMap.get(p.name)?.value ?? "unknown",
type: p.name,
};
}

const typeByOid = getTypeByOid(p.oid);

if (typeByOid.override) {
const baseType: ASTDescribedColumnType = { kind: "type", value: typeByOid.value };
const baseType: ASTDescribedColumnType = {
kind: "type",
value: typeByOid.value,
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
};
return typeByOid.isArray ? { kind: "array", value: baseType } : baseType;
}

const typeByBaseOid = fmap(p.baseOid, getTypeByOid);

if (typeByBaseOid?.override === true) {
const baseType: ASTDescribedColumnType = { kind: "type", value: typeByBaseOid.value };
const baseType: ASTDescribedColumnType = {
kind: "type",
value: typeByBaseOid.value,
type: params.pgTypes.get(p.baseOid!)?.name ?? "unknown",
};
return typeByBaseOid.isArray ? { kind: "array", value: baseType } : baseType;
}

Expand All @@ -104,13 +117,21 @@ export function getASTDescription(params: ASTDescriptionOptions): Map<number, AS
if (enumValue !== undefined) {
return {
kind: "union",
value: enumValue.values.map((value) => ({ kind: "type", value: `'${value}'` })),
value: enumValue.values.map((value) => ({
kind: "type",
value: `'${value}'`,
type: enumValue.name,
})),
};
}

const { isArray, value } = typeByBaseOid ?? typeByOid;

const type: ASTDescribedColumnType = { kind: "type", value: value };
const type: ASTDescribedColumnType = {
kind: "type",
value: value,
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
};

return isArray ? { kind: "array", value: type } : type;
},
Expand Down Expand Up @@ -215,15 +236,81 @@ function getDescribedNode(params: {

function getDescribedAExpr({
alias,
node,
context,
}: GetDescribedParamsOf<LibPgQueryAST.AExpr>): ASTDescribedColumn[] {
const name = alias ?? "?column?";

if (node.lexpr === undefined && node.rexpr !== undefined) {
const described = getDescribedNode({ alias, node: node.rexpr, context }).at(0);
const type = fmap(described, (x) => getBaseType(x.type));

if (type === null) return [];

return [{ name, type }];
}

if (node.lexpr === undefined || node.rexpr === undefined) {
return [];
}

const getResolvedNullableValueOrNull = (node: LibPgQueryAST.Node) => {
const column = getDescribedNode({ alias: undefined, node, context }).at(0);

if (column === undefined) return null;

if (column.type.kind === "array") {
return { value: "array", nullable: false };
}

if (column.type.kind === "type") {
return { value: column.type.type, nullable: false };
}

if (column.type.kind === "literal" && column.type.base.kind === "type") {
return { value: column.type.base.type, nullable: false };
}

if (column.type.kind === "union" && isTuple(column.type.value)) {
let nullable = false;
let value: string | undefined = undefined;

for (const type of column.type.value) {
if (type.kind !== "type") return null;
if (type.value === "null") nullable = true;
if (type.value !== "null") value = type.type;
}

if (value === undefined) return null;

return { value, nullable };
}

return null;
};

const lnode = getResolvedNullableValueOrNull(node.lexpr);
const rnode = getResolvedNullableValueOrNull(node.rexpr);

if (lnode === null || rnode === null) {
return [];
}

const operator = concatStringNodes(node.name);
const resolved: string | undefined =
defaultTypeExprMapping[`${lnode.value} ${operator} ${rnode.value}`];

if (resolved === undefined) {
return [];
}

return [
{
name: alias ?? "?column?",
name: name,
type: resolveType({
context: context,
nullable: false,
type: context.toTypeScriptType({ name: "boolean" }),
nullable: !context.nonNullableColumns.has(name) && (lnode.nullable || rnode.nullable),
type: context.toTypeScriptType({ name: resolved }),
}),
},
];
Expand All @@ -239,7 +326,7 @@ function getDescribedNullTest({
type: resolveType({
context: context,
nullable: false,
type: context.toTypeScriptType({ name: "boolean" }),
type: context.toTypeScriptType({ name: "bool" }),
}),
},
];
Expand Down Expand Up @@ -298,7 +385,7 @@ function getDescribedBoolExpr({
type: resolveType({
context: context,
nullable: false,
type: context.toTypeScriptType({ name: "boolean" }),
type: context.toTypeScriptType({ name: "bool" }),
}),
},
];
Expand All @@ -317,7 +404,7 @@ function getDescribedSubLink({
nullable: false,
type: (() => {
if (node.subLinkType === LibPgQueryAST.SubLinkType.EXISTS_SUBLINK) {
return context.toTypeScriptType({ name: "boolean" });
return context.toTypeScriptType({ name: "bool" });
}

return context.toTypeScriptType({ name: "unknown" });
Expand Down Expand Up @@ -412,7 +499,7 @@ function mergeDescribedColumnTypes(types: ASTDescribedColumnType[]): ASTDescribe

if (!seenSymbols.has("boolean") && seenSymbols.has("true") && seenSymbols.has("false")) {
seenSymbols.add("boolean");
result.push({ kind: "type", value: "boolean" });
result.push({ kind: "type", value: "boolean", type: "bool" });
}

if (seenSymbols.has("boolean") && (seenSymbols.has("true") || seenSymbols.has("false"))) {
Expand Down Expand Up @@ -537,15 +624,15 @@ function getDescribedFuncCallByPgFn({

const pgFnValue =
args.length === 0
? context.pgFns.get(functionName)
? (context.pgFns.get(functionName) ?? context.pgFns.get(`${functionName}(string)`))
: (context.pgFns.get(`${functionName}(${args.join(", ")})`) ??
context.pgFns.get(`${functionName}(any)`) ??
context.pgFns.get(`${functionName}(unknown)`));

const type = resolveType({
context: context,
nullable: !context.nonNullableColumns.has(name),
type: { kind: "type", value: pgFnValue ?? "unknown" },
type: { kind: "type", value: pgFnValue?.ts ?? "unknown", type: pgFnValue?.pg ?? "unknown" },
});

return [{ name, type }];
Expand Down Expand Up @@ -758,7 +845,11 @@ function getDescribedColumnByResolvedColumns(params: {
?.get(column.colName);

if (overridenType !== undefined) {
return { kind: "type", value: overridenType };
return {
kind: "type",
value: overridenType,
type: params.context.pgTypes.get(column.colTypeOid)?.name ?? "unknown",
};
}

return params.context.toTypeScriptType({
Expand Down Expand Up @@ -789,7 +880,7 @@ function getDescribedAConst({
return {
kind: "literal",
value: node.boolval.boolval ? "true" : "false",
base: context.toTypeScriptType({ name: "boolean" }),
base: context.toTypeScriptType({ name: "bool" }),
};
case node.bsval !== undefined:
return context.toTypeScriptType({ name: "bytea" });
Expand Down Expand Up @@ -838,7 +929,7 @@ function asNonNullableType(type: ASTDescribedColumnType): ASTDescribedColumnType
);

if (filtered.length === 0) {
return { kind: "type", value: "unknown" };
return { kind: "type", value: "unknown", type: "unknown" };
}

if (filtered.length === 1) {
Expand All @@ -848,7 +939,7 @@ function asNonNullableType(type: ASTDescribedColumnType): ASTDescribedColumnType
return { kind: "union", value: filtered };
}
case "type":
return type.value === "null" ? { kind: "type", value: "unknown" } : type;
return type.value === "null" ? { kind: "type", value: "unknown", type: "unknown" } : type;
}
}

Expand Down
Loading
Loading