Skip to content

Commit caa07a6

Browse files
authored
fix: prevent SQL injection in model-data API (#304)
* fix: prevent SQL injection in model-data API via identifier validation * fix: add model metadata whitelist validation to DynamicModelService * fix: resolve circular dependency between DynamicModelService and ModelServiceImpl * fix: make orderType and orderBy case-insensitive across validation layers * fix: allow queryWithPage with null/empty params for list-all queries
1 parent c5147b9 commit caa07a6

7 files changed

Lines changed: 795 additions & 238 deletions

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.tinyengine.it.common.utils;
2+
3+
import java.util.List;
4+
import java.util.regex.Pattern;
5+
6+
public class SqlIdentifierValidator {
7+
8+
private static final Pattern IDENTIFIER_PATTERN =
9+
Pattern.compile("^[a-zA-Z_][a-zA-Z0-9_]*$");
10+
11+
private static final Pattern ORDER_TYPE_PATTERN =
12+
Pattern.compile("^(ASC|DESC)$", Pattern.CASE_INSENSITIVE);
13+
14+
private SqlIdentifierValidator() {
15+
}
16+
17+
public static void validate(String identifier) {
18+
if (identifier == null || !IDENTIFIER_PATTERN.matcher(identifier).matches()) {
19+
throw new IllegalArgumentException("Invalid SQL identifier: " + identifier);
20+
}
21+
}
22+
23+
public static void validateAll(List<String> identifiers) {
24+
if (identifiers == null) {
25+
return;
26+
}
27+
identifiers.forEach(SqlIdentifierValidator::validate);
28+
}
29+
30+
public static void validateOrderType(String orderType) {
31+
if (orderType == null || !ORDER_TYPE_PATTERN.matcher(orderType).matches()) {
32+
throw new IllegalArgumentException("Invalid order type: " + orderType);
33+
}
34+
}
35+
}
Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.tinyengine.it.dynamic.dto;
22

3+
import jakarta.validation.constraints.Pattern;
34
import lombok.Data;
45

56
import java.util.List;
@@ -8,12 +9,19 @@
89
@Data
910
public class DynamicQuery {
1011

11-
private String nameEn; // 表名
12-
private String nameCh; // 表中文名
13-
private List<String> fields; // 查询字段
14-
private Map<String, Object> params; // 查询条件
15-
private Integer currentPage = 1; // 页码
16-
private Integer pageSize = 10; // 每页大小
17-
private String orderBy; // 排序字段
18-
private String orderType = "ASC"; // 排序方式
12+
@Pattern(regexp = "^[a-zA-Z_][a-zA-Z0-9_]*$", message = "模型名称格式不正确")
13+
private String nameEn;
14+
private String nameCh;
15+
private List<
16+
@Pattern(regexp = "^[a-zA-Z_][a-zA-Z0-9_]*$", message = "字段名格式不正确")
17+
String> fields;
18+
private Map<String, Object> params;
19+
private Integer currentPage = 1;
20+
private Integer pageSize = 10;
21+
22+
@Pattern(regexp = "^[a-zA-Z_][a-zA-Z0-9_]*$", message = "排序字段格式不正确")
23+
private String orderBy;
24+
25+
@Pattern(regexp = "^(?i)(ASC|DESC)$", message = "排序方式只能是 ASC 或 DESC")
26+
private String orderType = "ASC";
1927
}

base/src/main/java/com/tinyengine/it/dynamic/service/DynamicModelService.java

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
import com.tinyengine.it.dynamic.dto.DynamicUpdate;
1212
import com.tinyengine.it.model.dto.ParametersDto;
1313
import com.tinyengine.it.model.entity.Model;
14-
import lombok.RequiredArgsConstructor;
1514
import lombok.extern.slf4j.Slf4j;
1615

16+
import com.tinyengine.it.common.utils.SqlIdentifierValidator;
17+
import com.tinyengine.it.service.material.ModelService;
18+
import org.springframework.context.annotation.Lazy;
1719
import org.springframework.jdbc.core.JdbcTemplate;
1820
import org.springframework.jdbc.core.PreparedStatementCreator;
1921
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
@@ -31,12 +33,26 @@
3133

3234
@Service
3335
@Slf4j
34-
@RequiredArgsConstructor
3536
public class DynamicModelService {
3637

38+
private static final Set<String> SYSTEM_FIELDS = Set.of(
39+
"id", "created_at", "updated_at", "deleted_at", "created_by", "updated_by"
40+
);
41+
3742
private final JdbcTemplate jdbcTemplate;
3843
private final NamedParameterJdbcTemplate namedParameterJdbcTemplate;
3944
private final LoginUserContext loginUserContext;
45+
private final ModelService modelService;
46+
47+
public DynamicModelService(JdbcTemplate jdbcTemplate,
48+
NamedParameterJdbcTemplate namedParameterJdbcTemplate,
49+
LoginUserContext loginUserContext,
50+
@Lazy ModelService modelService) {
51+
this.jdbcTemplate = jdbcTemplate;
52+
this.namedParameterJdbcTemplate = namedParameterJdbcTemplate;
53+
this.loginUserContext = loginUserContext;
54+
this.modelService = modelService;
55+
}
4056

4157

4258
/**
@@ -182,6 +198,17 @@ public List<Map<String, Object>> dynamicQuery(String tableName,
182198
String orderBy,
183199
Integer limit) {
184200

201+
SqlIdentifierValidator.validate(tableName);
202+
SqlIdentifierValidator.validateAll(fields);
203+
if (conditions != null && !conditions.isEmpty()) {
204+
for (String key : conditions.keySet()) {
205+
SqlIdentifierValidator.validate(key);
206+
}
207+
}
208+
if (orderBy != null && !orderBy.isEmpty()) {
209+
SqlIdentifierValidator.validate(orderBy.replaceAll("(?i)\\s+(ASC|DESC)$", ""));
210+
}
211+
185212
// 1. 构建SQL
186213
StringBuilder sql = new StringBuilder("SELECT ");
187214

@@ -267,18 +294,19 @@ public Long count(String tableName, Map<String, Object> conditions) {
267294
* 分页查询
268295
*/
269296
public Map<String, Object> queryWithPage(DynamicQuery dto) {
270-
String tableName = getTableName( dto.getNameEn());
297+
String tableName = getTableName(dto.getNameEn());
271298
List<String> fields = dto.getFields();
272299
Map<String, Object> conditions = dto.getParams();
273300
String orderBy = dto.getOrderBy();
274301
Integer pageNum = dto.getCurrentPage();
275302
Integer pageSize = dto.getPageSize();
276303

304+
validateQueryFields(dto);
305+
277306
// 计算分页
278307
Integer limit = null;
279308
if (pageNum != null && pageSize != null) {
280309
limit = pageSize;
281-
// 如果需要偏移量,可以在这里处理
282310
}
283311

284312
// 执行查询
@@ -292,6 +320,60 @@ public Map<String, Object> queryWithPage(DynamicQuery dto) {
292320

293321
return result;
294322
}
323+
324+
private Set<String> getAllowedFields(String nameEn) {
325+
List<Model> modelList = modelService.getModelByEnName(nameEn);
326+
if (modelList == null || modelList.isEmpty()) {
327+
return Collections.emptySet();
328+
}
329+
Model model = modelList.get(0);
330+
Set<String> allowed = new HashSet<>(SYSTEM_FIELDS);
331+
if (model.getParameters() != null) {
332+
for (Object param : model.getParameters()) {
333+
String prop = extractProp(param);
334+
if (prop != null) {
335+
allowed.add(prop);
336+
}
337+
}
338+
}
339+
return allowed;
340+
}
341+
342+
@SuppressWarnings("unchecked")
343+
private String extractProp(Object param) {
344+
if (param instanceof ParametersDto) {
345+
return ((ParametersDto) param).getProp();
346+
}
347+
if (param instanceof Map) {
348+
Object value = ((Map<String, Object>) param).get("prop");
349+
return value != null ? value.toString() : null;
350+
}
351+
return null;
352+
}
353+
354+
private void validateQueryFields(DynamicQuery dto) {
355+
Set<String> allowedFields = getAllowedFields(dto.getNameEn());
356+
357+
if (dto.getFields() != null && !dto.getFields().isEmpty()) {
358+
for (String field : dto.getFields()) {
359+
SqlIdentifierValidator.validate(field);
360+
if (!allowedFields.contains(field)) {
361+
throw new IllegalArgumentException("不允许的字段: " + field);
362+
}
363+
}
364+
}
365+
366+
if (dto.getOrderBy() != null && !dto.getOrderBy().isEmpty()) {
367+
SqlIdentifierValidator.validate(dto.getOrderBy());
368+
if (!allowedFields.contains(dto.getOrderBy())) {
369+
throw new IllegalArgumentException("不允许的排序字段: " + dto.getOrderBy());
370+
}
371+
}
372+
373+
if (dto.getOrderType() != null) {
374+
SqlIdentifierValidator.validateOrderType(dto.getOrderType());
375+
}
376+
}
295377
private Object convertValueByType(Object value, String fieldType, String columnName) {
296378
try {
297379
switch (fieldType) {
@@ -525,6 +607,9 @@ public Map<String, Object> createData(DynamicInsert dataDto) {
525607

526608
String tableName = getTableName(dataDto.getNameEn());
527609
Map<String, Object> record = new HashMap<>(dataDto.getParams());
610+
for (String col : record.keySet()) {
611+
SqlIdentifierValidator.validate(col);
612+
}
528613
String userId = loginUserContext.getLoginUserId();
529614
// 添加系统字段
530615
record.put("created_by",userId);
@@ -606,6 +691,9 @@ public Map<String,Object> updateDateById(DynamicUpdate dto) {
606691
}
607692
Long id = Long.parseLong(params1.get("id").toString());
608693
Map<String, Object> updateFields = dto.getData();
694+
for (String col : updateFields.keySet()) {
695+
SqlIdentifierValidator.validate(col);
696+
}
609697
String tableName = getTableName(modelId);
610698
StringBuilder sql = new StringBuilder("UPDATE " + tableName + " SET ");
611699
List<Object> params = new ArrayList<>();

0 commit comments

Comments
 (0)