Skip to content

Commit 8764f36

Browse files
committed
feat(d2): database/dialect.py - 多方言 SQL-Java 类型映射
新增 DialectAdapter 抽象基类 + MySQLDialect + PostgreSQLDialect: - MySQL: 从 template_context.py 抽出来的现有映射 (无行为变化) - PostgreSQL 12+: 新增,重点覆盖差异类型 * INT2/INT4/INT8 + SERIAL/BIGSERIAL 系列 * 原生 BOOLEAN (不走 MySQL 的 TINYINT(1) hack) * BYTEA / JSON / JSONB / UUID / INET * TIMESTAMPTZ -> OffsetDateTime (跨时区场景) * CHARACTER VARYING (VARCHAR 官方全名) - get_dialect(name) registry,未识别退回 MySQL (向后兼容) 36 个 unit test: - 11 MySQL 映射 (含 TINYINT(1) 布尔) - 13 PostgreSQL 映射 (含 SERIAL / TIMESTAMPTZ / JSONB) - 3 跨方言隔离测试 (确保 PG 没 TINYINT / MySQL 没 BYTEA) - registry / strip_paren / abstract 校验 dialect.py 100% line coverage。 Why: PostgreSQL 整数族 (INT2/INT4/INT8) + 自增 (SERIAL) + 原生类型 (UUID/ JSONB/TIMESTAMPTZ) 跟 MySQL 差异大,硬塞进 template_context 的映射会乱。 独立成模块后,Oracle/SQLServer 加进来也只是新增一个 adapter。 下一步 (D3) 用 Testcontainers 起 postgres:16,跑出来的真实类型走这套映射。
1 parent 12195f2 commit 8764f36

2 files changed

Lines changed: 506 additions & 0 deletions

File tree

Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
"""
2+
P9.2: Database dialect adapters.
3+
4+
为何独立成模块:
5+
- v0.1 ~ v0.2 时期 SQL→Java 类型映射散落在 template_context.py 里,只考虑 MySQL
6+
- PostgreSQL 引入后,差异点:
7+
* 整数: INT2/INT4/INT8 vs MySQL 的 INT/BIGINT
8+
* 序列: SERIAL/BIGSERIAL (PG 自增惯例,无 AUTO_INCREMENT 关键字)
9+
* 布尔: PG 原生 BOOLEAN, MySQL 用 TINYINT(1) 模拟
10+
* 二进制: BYTEA vs BLOB 系列
11+
* JSON: JSON + JSONB (两种,PG 推荐 JSONB)
12+
* UUID: PG 原生 UUID, MySQL 没有 → CHAR(36) 存
13+
* 时区: TIMESTAMPTZ → OffsetDateTime,普通 TIMESTAMP → LocalDateTime
14+
- 把方言策略和模板上下文剥离开,后续要加 Oracle / SQLServer 时加一个 adapter 就行
15+
16+
调用方:
17+
- template_context.py (TODO 下个迭代切过去,保持向后兼容)
18+
- codegen_tools.py / atomic_codegen_tools.py
19+
- 集成测试 (PostgreSQL Testcontainers 跑出来的类型也用同一套映射)
20+
"""
21+
22+
from __future__ import annotations
23+
24+
import re
25+
from abc import ABC, abstractmethod
26+
from typing import Dict
27+
28+
29+
_PAREN_PATTERN = re.compile(r"\([^)]*\)")
30+
31+
32+
def _strip_paren(db_type: str) -> str:
33+
"""剥离 `VARCHAR(64)` 这种带长度的括号部分,统一小写无空格。"""
34+
return _PAREN_PATTERN.sub("", db_type.upper()).strip()
35+
36+
37+
# ============================================================
38+
# 基类
39+
# ============================================================
40+
41+
42+
class DialectAdapter(ABC):
43+
"""方言适配器基类。每个 DB 厂家一个子类。"""
44+
45+
name: str = "abstract"
46+
47+
@property
48+
@abstractmethod
49+
def type_to_java(self) -> Dict[str, str]:
50+
"""SQL base type → Java type. base type 已 strip 括号。"""
51+
52+
@property
53+
@abstractmethod
54+
def type_to_jdbc(self) -> Dict[str, str]:
55+
"""SQL base type → JDBC type 字符串 (用于 MyBatis Plus jdbcType)。"""
56+
57+
@property
58+
def string_types(self) -> set[str]:
59+
"""该方言下视为 String 的类型集合 (base type)。"""
60+
return {k for k, v in self.type_to_java.items() if v == "String"}
61+
62+
@property
63+
def date_types(self) -> set[str]:
64+
"""日期/时间类型 (用于决定是否 import LocalDate 等)。"""
65+
return {
66+
k
67+
for k, v in self.type_to_java.items()
68+
if v in {"LocalDate", "LocalTime", "LocalDateTime", "OffsetDateTime", "Instant"}
69+
}
70+
71+
@property
72+
def decimal_types(self) -> set[str]:
73+
return {k for k, v in self.type_to_java.items() if v == "BigDecimal"}
74+
75+
def java_type_for(self, db_type: str) -> str:
76+
"""主入口: 给一个数据库声明类型,返回 Java 类型 (找不到默认 String)。"""
77+
base = _strip_paren(db_type)
78+
# TINYINT(1) 是 MySQL 布尔约定 — 单独判
79+
if db_type.upper().replace(" ", "") == "TINYINT(1)":
80+
return self.type_to_java.get("TINYINT(1)", self.type_to_java.get(base, "String"))
81+
return self.type_to_java.get(base, "String")
82+
83+
def jdbc_type_for(self, db_type: str) -> str:
84+
base = _strip_paren(db_type)
85+
return self.type_to_jdbc.get(base, "VARCHAR")
86+
87+
def is_string_type(self, db_type: str) -> bool:
88+
return _strip_paren(db_type) in self.string_types
89+
90+
def is_date_type(self, db_type: str) -> bool:
91+
return _strip_paren(db_type) in self.date_types
92+
93+
def is_decimal_type(self, db_type: str) -> bool:
94+
return _strip_paren(db_type) in self.decimal_types
95+
96+
97+
# ============================================================
98+
# MySQL (从 template_context.py 抽出来的现有逻辑)
99+
# ============================================================
100+
101+
102+
class MySQLDialect(DialectAdapter):
103+
name = "mysql"
104+
105+
@property
106+
def type_to_java(self) -> Dict[str, str]:
107+
return {
108+
# 整数
109+
"TINYINT": "Byte",
110+
"TINYINT(1)": "Boolean", # 经典 MySQL 布尔约定
111+
"SMALLINT": "Short",
112+
"MEDIUMINT": "Integer",
113+
"INT": "Integer",
114+
"INTEGER": "Integer",
115+
"BIGINT": "Long",
116+
# 浮点
117+
"FLOAT": "Float",
118+
"DOUBLE": "Double",
119+
"DECIMAL": "BigDecimal",
120+
"NUMERIC": "BigDecimal",
121+
# 字符串
122+
"CHAR": "String",
123+
"VARCHAR": "String",
124+
"TEXT": "String",
125+
"LONGTEXT": "String",
126+
"MEDIUMTEXT": "String",
127+
"TINYTEXT": "String",
128+
"NCHAR": "String",
129+
"NVARCHAR": "String",
130+
"ENUM": "String",
131+
"SET": "String",
132+
# 日期
133+
"DATE": "LocalDate",
134+
"TIME": "LocalTime",
135+
"DATETIME": "LocalDateTime",
136+
"TIMESTAMP": "LocalDateTime",
137+
"YEAR": "Integer",
138+
# 布尔
139+
"BOOLEAN": "Boolean",
140+
"BOOL": "Boolean",
141+
# 二进制
142+
"BINARY": "byte[]",
143+
"VARBINARY": "byte[]",
144+
"BLOB": "byte[]",
145+
"LONGBLOB": "byte[]",
146+
"MEDIUMBLOB": "byte[]",
147+
"TINYBLOB": "byte[]",
148+
# JSON
149+
"JSON": "String",
150+
}
151+
152+
@property
153+
def type_to_jdbc(self) -> Dict[str, str]:
154+
return {
155+
"TINYINT": "TINYINT",
156+
"SMALLINT": "SMALLINT",
157+
"MEDIUMINT": "INTEGER",
158+
"INT": "INTEGER",
159+
"INTEGER": "INTEGER",
160+
"BIGINT": "BIGINT",
161+
"FLOAT": "FLOAT",
162+
"DOUBLE": "DOUBLE",
163+
"DECIMAL": "DECIMAL",
164+
"NUMERIC": "NUMERIC",
165+
"CHAR": "CHAR",
166+
"VARCHAR": "VARCHAR",
167+
"TEXT": "LONGVARCHAR",
168+
"LONGTEXT": "LONGVARCHAR",
169+
"MEDIUMTEXT": "LONGVARCHAR",
170+
"TINYTEXT": "VARCHAR",
171+
"NCHAR": "NCHAR",
172+
"NVARCHAR": "NVARCHAR",
173+
"DATE": "DATE",
174+
"TIME": "TIME",
175+
"DATETIME": "TIMESTAMP",
176+
"TIMESTAMP": "TIMESTAMP",
177+
"BOOLEAN": "BOOLEAN",
178+
"TINYINT(1)": "BOOLEAN",
179+
"BLOB": "BLOB",
180+
"JSON": "LONGVARCHAR",
181+
}
182+
183+
184+
# ============================================================
185+
# PostgreSQL
186+
# ============================================================
187+
188+
189+
class PostgreSQLDialect(DialectAdapter):
190+
"""PostgreSQL 12+ 方言。
191+
192+
主要差异 vs MySQL:
193+
- 整数族用 INT2/INT4/INT8 (官方别名 SMALLINT/INTEGER/BIGINT 也支持)
194+
- SERIAL/BIGSERIAL: 自增列,JDBC 端按 Integer/Long 处理
195+
- BOOLEAN 原生支持 (不需要 TINYINT(1) hack)
196+
- BYTEA 替代 BLOB
197+
- JSONB 推荐 (二进制 JSON,带索引能力);JSON 也保留
198+
- UUID 原生支持
199+
- TIMESTAMPTZ → OffsetDateTime (跨时区场景必须保留时区)
200+
- "CHARACTER VARYING" 是 VARCHAR 的官方全名
201+
"""
202+
203+
name = "postgresql"
204+
205+
@property
206+
def type_to_java(self) -> Dict[str, str]:
207+
return {
208+
# 整数
209+
"INT2": "Short",
210+
"SMALLINT": "Short",
211+
"INT4": "Integer",
212+
"INT": "Integer",
213+
"INTEGER": "Integer",
214+
"INT8": "Long",
215+
"BIGINT": "Long",
216+
# SERIAL 系列 (PG 自增惯例) - JDBC 上和普通整数一致
217+
"SMALLSERIAL": "Short",
218+
"SERIAL": "Integer",
219+
"BIGSERIAL": "Long",
220+
# 浮点
221+
"REAL": "Float",
222+
"FLOAT4": "Float",
223+
"DOUBLE PRECISION": "Double",
224+
"FLOAT8": "Double",
225+
"NUMERIC": "BigDecimal",
226+
"DECIMAL": "BigDecimal",
227+
"MONEY": "BigDecimal",
228+
# 字符串
229+
"CHAR": "String",
230+
"CHARACTER": "String",
231+
"VARCHAR": "String",
232+
"CHARACTER VARYING": "String",
233+
"TEXT": "String",
234+
"BPCHAR": "String", # internal name for CHAR
235+
# 日期
236+
"DATE": "LocalDate",
237+
"TIME": "LocalTime",
238+
"TIMETZ": "OffsetTime",
239+
"TIMESTAMP": "LocalDateTime",
240+
"TIMESTAMPTZ": "OffsetDateTime", # 跨时区必须 OffsetDateTime
241+
"TIMESTAMP WITH TIME ZONE": "OffsetDateTime",
242+
# 布尔
243+
"BOOL": "Boolean",
244+
"BOOLEAN": "Boolean",
245+
# 二进制
246+
"BYTEA": "byte[]",
247+
# JSON
248+
"JSON": "String",
249+
"JSONB": "String",
250+
# UUID
251+
"UUID": "String",
252+
# 网络
253+
"INET": "String",
254+
"CIDR": "String",
255+
# 数组占位 — Java 端建议 String,具体由模板决定
256+
"ARRAY": "String",
257+
}
258+
259+
@property
260+
def type_to_jdbc(self) -> Dict[str, str]:
261+
return {
262+
"INT2": "SMALLINT",
263+
"SMALLINT": "SMALLINT",
264+
"INT4": "INTEGER",
265+
"INT": "INTEGER",
266+
"INTEGER": "INTEGER",
267+
"INT8": "BIGINT",
268+
"BIGINT": "BIGINT",
269+
"SMALLSERIAL": "SMALLINT",
270+
"SERIAL": "INTEGER",
271+
"BIGSERIAL": "BIGINT",
272+
"REAL": "REAL",
273+
"FLOAT4": "REAL",
274+
"DOUBLE PRECISION": "DOUBLE",
275+
"FLOAT8": "DOUBLE",
276+
"NUMERIC": "NUMERIC",
277+
"DECIMAL": "DECIMAL",
278+
"MONEY": "DECIMAL",
279+
"CHAR": "CHAR",
280+
"CHARACTER": "CHAR",
281+
"BPCHAR": "CHAR",
282+
"VARCHAR": "VARCHAR",
283+
"CHARACTER VARYING": "VARCHAR",
284+
"TEXT": "LONGVARCHAR",
285+
"DATE": "DATE",
286+
"TIME": "TIME",
287+
"TIMETZ": "TIME_WITH_TIMEZONE",
288+
"TIMESTAMP": "TIMESTAMP",
289+
"TIMESTAMPTZ": "TIMESTAMP_WITH_TIMEZONE",
290+
"TIMESTAMP WITH TIME ZONE": "TIMESTAMP_WITH_TIMEZONE",
291+
"BOOL": "BOOLEAN",
292+
"BOOLEAN": "BOOLEAN",
293+
"BYTEA": "BINARY",
294+
"JSON": "OTHER",
295+
"JSONB": "OTHER",
296+
"UUID": "OTHER",
297+
"INET": "OTHER",
298+
"CIDR": "OTHER",
299+
}
300+
301+
302+
# ============================================================
303+
# Registry
304+
# ============================================================
305+
306+
307+
_REGISTRY: Dict[str, DialectAdapter] = {
308+
"mysql": MySQLDialect(),
309+
"postgresql": PostgreSQLDialect(),
310+
}
311+
312+
313+
def get_dialect(db_type: str) -> DialectAdapter:
314+
"""根据 DatabaseType 字符串 (mysql / postgresql / ...) 取适配器。
315+
316+
未识别的 db_type 退回 MySQL — 保持向后兼容,避免老调用方崩。
317+
"""
318+
return _REGISTRY.get(db_type.lower(), _REGISTRY["mysql"])
319+
320+
321+
def list_supported_dialects() -> list[str]:
322+
return sorted(_REGISTRY.keys())

0 commit comments

Comments
 (0)