Skip to content

Commit 8f15d7f

Browse files
authored
Disallow explicit null for optional fields (microsoft#3313)
1 parent a3eef87 commit 8f15d7f

5 files changed

Lines changed: 10145 additions & 2201 deletions

File tree

internal/json/json.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func NewDecoder(r io.Reader) *jsontext.Decoder {
8080

8181
type (
8282
Value = jsontext.Value
83+
Kind = jsontext.Kind
8384
UnmarshalerFrom = json.UnmarshalerFrom
8485
MarshalerTo = json.MarshalerTo
8586
Decoder = jsontext.Decoder

internal/lsp/lsproto/_generate/generate.mts

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,40 @@ function goKindCasesForJsonKind(kind: string): string {
12471247
}
12481248
}
12491249

1250+
/**
1251+
* Checks if a meta model Type can represent a JSON null value.
1252+
* Used to determine whether to reject explicit JSON `null` for any field
1253+
* that can otherwise decode `null` without a type error.
1254+
*/
1255+
function typeCanBeNull(type: Type): boolean {
1256+
switch (type.kind) {
1257+
case "base":
1258+
return type.name === "null";
1259+
case "reference": {
1260+
const override = typeAliasOverrides.get(type.name);
1261+
if (override) {
1262+
return override.name === "any";
1263+
}
1264+
// A bare "any" reference resolves to Go's `any` (interface), which can hold null.
1265+
if (type.name === "any") {
1266+
return true;
1267+
}
1268+
if (nonResolvedAliases.has(type.name)) {
1269+
const customAlias = customTypeAliases.find(t => t.name === type.name);
1270+
if (customAlias) return typeCanBeNull(customAlias.type);
1271+
return false;
1272+
}
1273+
const aliased = typeInfo.typeAliasMap.get(type.name);
1274+
if (aliased) return typeCanBeNull(aliased);
1275+
return false;
1276+
}
1277+
case "or":
1278+
return type.items.some(item => typeCanBeNull(item));
1279+
default:
1280+
return false;
1281+
}
1282+
}
1283+
12501284
/**
12511285
* For a group of union entries that share the same JSON kind (e.g., all objects),
12521286
* find a discriminator field — a JSON property whose string literal type differs
@@ -1728,24 +1762,33 @@ function generateCode() {
17281762
if (p.omitzeroValue) return false;
17291763
return true;
17301764
}) || [];
1731-
if (requiredProps.length > 0 && structure.name !== "Registration") {
1765+
// Check if any fields need null rejection
1766+
const hasNullRejectableFields = structure.properties?.some(p => {
1767+
if (p.omitzeroValue) return false;
1768+
if (typeCanBeNull(p.type)) return false;
1769+
const resolved = resolveType(p.type);
1770+
return p.optional || resolved.needsPointer || resolved.name.startsWith("[]") || resolved.name.startsWith("map[");
1771+
}) || false;
1772+
if ((requiredProps.length > 0 || hasNullRejectableFields) && structure.name !== "Registration") {
17321773
writeLine(`\tvar _ json.UnmarshalerFrom = (*${structure.name})(nil)`);
17331774
writeLine("");
17341775

17351776
writeLine(`func (s *${structure.name}) UnmarshalJSONFrom(dec *json.Decoder) error {`);
1736-
writeLine(`\tconst (`);
1737-
for (let i = 0; i < requiredProps.length; i++) {
1738-
const prop = requiredProps[i];
1739-
const iotaPrefix = i === 0 ? " uint = 1 << iota" : "";
1740-
writeLine(`\t\tmissing${titleCase(prop.name)}${iotaPrefix}`);
1777+
if (requiredProps.length > 0) {
1778+
writeLine(`\tconst (`);
1779+
for (let i = 0; i < requiredProps.length; i++) {
1780+
const prop = requiredProps[i];
1781+
const iotaPrefix = i === 0 ? " uint = 1 << iota" : "";
1782+
writeLine(`\t\tmissing${titleCase(prop.name)}${iotaPrefix}`);
1783+
}
1784+
writeLine(`\t\t_missingLast`);
1785+
writeLine(`\t)`);
1786+
writeLine(`\tmissing := _missingLast - 1`);
1787+
writeLine("");
17411788
}
1742-
writeLine(`\t\t_missingLast`);
1743-
writeLine(`\t)`);
1744-
writeLine(`\tmissing := _missingLast - 1`);
1745-
writeLine("");
17461789

17471790
writeLine(`\tif k := dec.PeekKind(); k != '{' {`);
1748-
writeLine(`\t\treturn fmt.Errorf("expected object start, but encountered %v", k)`);
1791+
writeLine(`\t\treturn errNotObject(k)`);
17491792
writeLine(`\t}`);
17501793
writeLine(`\tif _, err := dec.ReadToken(); err != nil {`);
17511794
writeLine(`\t\treturn err`);
@@ -1764,6 +1807,15 @@ function generateCode() {
17641807
if (!prop.optional && !prop.omitzeroValue) {
17651808
writeLine(`\t\t\tmissing &^= missing${titleCase(prop.name)}`);
17661809
}
1810+
// Reject null for fields whose types cannot represent null but whose Go types
1811+
// silently accept it (pointers, slices, maps).
1812+
const resolvedType = resolveType(prop.type);
1813+
const goTypeAcceptsNull = (prop.optional || resolvedType.needsPointer || resolvedType.name.startsWith("[]") || resolvedType.name.startsWith("map[")) && !prop.omitzeroValue;
1814+
if (goTypeAcceptsNull && !typeCanBeNull(prop.type)) {
1815+
writeLine(`\t\t\tif dec.PeekKind() == 'n' {`);
1816+
writeLine(`\t\t\t\treturn errNull("${prop.name}")`);
1817+
writeLine(`\t\t\t}`);
1818+
}
17671819
writeLine(`\t\t\tif err := json.UnmarshalDecode(dec, &s.${titleCase(prop.name)}); err != nil {`);
17681820
writeLine(`\t\t\t\treturn err`);
17691821
writeLine(`\t\t\t}`);
@@ -1782,17 +1834,19 @@ function generateCode() {
17821834
writeLine(`\t}`);
17831835
writeLine("");
17841836

1785-
writeLine(`\tif missing != 0 {`);
1786-
writeLine(`\t\tvar missingProps []string`);
1787-
for (const prop of requiredProps) {
1788-
writeLine(`\t\tif missing&missing${titleCase(prop.name)} != 0 {`);
1789-
writeLine(`\t\t\tmissingProps = append(missingProps, "${prop.name}")`);
1790-
writeLine(`\t\t}`);
1837+
if (requiredProps.length > 0) {
1838+
writeLine(`\tif missing != 0 {`);
1839+
writeLine(`\t\tvar missingProps []string`);
1840+
for (const prop of requiredProps) {
1841+
writeLine(`\t\tif missing&missing${titleCase(prop.name)} != 0 {`);
1842+
writeLine(`\t\t\tmissingProps = append(missingProps, "${prop.name}")`);
1843+
writeLine(`\t\t}`);
1844+
}
1845+
writeLine(`\t\treturn errMissing(missingProps)`);
1846+
writeLine(`\t}`);
1847+
writeLine("");
17911848
}
1792-
writeLine(`\t\treturn fmt.Errorf("missing required properties: %s", strings.Join(missingProps, ", "))`);
1793-
writeLine(`\t}`);
17941849

1795-
writeLine("");
17961850
writeLine(`\treturn nil`);
17971851
writeLine(`}`);
17981852
writeLine("");
@@ -1872,7 +1926,7 @@ function generateCode() {
18721926
writeLine(`\tmissing := _missingLast - 1`);
18731927
writeLine("");
18741928
writeLine(`\tif k := dec.PeekKind(); k != '{' {`);
1875-
writeLine(`\t\treturn fmt.Errorf("expected object start, but encountered %v", k)`);
1929+
writeLine(`\t\treturn errNotObject(k)`);
18761930
writeLine(`\t}`);
18771931
writeLine(`\tif _, err := dec.ReadToken(); err != nil {`);
18781932
writeLine(`\t\treturn err`);
@@ -1922,7 +1976,7 @@ function generateCode() {
19221976
writeLine(`\t\tif missing&missingMethod != 0 {`);
19231977
writeLine(`\t\t\tmissingProps = append(missingProps, "method")`);
19241978
writeLine(`\t\t}`);
1925-
writeLine(`\t\treturn fmt.Errorf("missing required properties: %s", strings.Join(missingProps, ", "))`);
1979+
writeLine(`\t\treturn errMissing(missingProps)`);
19261980
writeLine(`\t}`);
19271981
writeLine("");
19281982
writeLine(`\tif len(rawRegisterOptions) > 0 {`);
@@ -2619,7 +2673,7 @@ function generateCode() {
26192673
}
26202674

26212675
writeLine(`\tdefault:`);
2622-
writeLine(`\t\treturn fmt.Errorf("invalid ${name}: expected ${[...(unionContainedNull ? ["null"] : []), ...kindMap.keys()].join(", ")}, got %v", dec.PeekKind())`);
2676+
writeLine(`\t\treturn errInvalidKind("${name}", dec.PeekKind())`);
26232677
writeLine(`\t}`);
26242678
}
26252679
else if (canDispatch) {
@@ -2681,13 +2735,13 @@ function generateCode() {
26812735
}
26822736
}
26832737
if (!exhaustive) {
2684-
writeLine(`\t\treturn fmt.Errorf("invalid ${name}: %s", data)`);
2738+
writeLine(`\t\treturn errInvalidValue("${name}", data)`);
26852739
}
26862740
}
26872741
}
26882742

26892743
writeLine(`\tdefault:`);
2690-
writeLine(`\t\treturn fmt.Errorf("invalid ${name}: expected ${[...(unionContainedNull ? ["null"] : []), ...kindMap.keys()].join(", ")}, got %v", dec.PeekKind())`);
2744+
writeLine(`\t\treturn errInvalidKind("${name}", dec.PeekKind())`);
26912745
writeLine(`\t}`);
26922746
}
26932747
else {
@@ -2732,7 +2786,7 @@ function generateCode() {
27322786
}
27332787
else if (!fallbackExhaustive) {
27342788
// Fallback paths: the final error references `data` which is in scope.
2735-
writeLine(`\treturn fmt.Errorf("invalid ${name}: %s", data)`);
2789+
writeLine(`\treturn errInvalidValue("${name}", data)`);
27362790
}
27372791
writeLine(`}`);
27382792
writeLine("");
@@ -2773,7 +2827,7 @@ function generateCode() {
27732827
writeLine(`\t\treturn err`);
27742828
writeLine(`\t}`);
27752829
writeLine(`\tif string(v) != \`${jsonValue}\` {`);
2776-
writeLine(`\t\treturn fmt.Errorf("expected ${name} value %s, got %s", \`${jsonValue}\`, v)`);
2830+
writeLine(`\t\treturn errLiteralMismatch("${name}", \`${jsonValue}\`, v)`);
27772831
writeLine(`\t}`);
27782832
writeLine(`\treturn nil`);
27792833
writeLine(`}`);

internal/lsp/lsproto/lsp.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,30 @@ func boolToInt(b bool) int {
122122
return 0
123123
}
124124

125+
func errNotObject(k json.Kind) error {
126+
return fmt.Errorf("expected object start, but encountered %v", k)
127+
}
128+
129+
func errNull(field string) error {
130+
return fmt.Errorf("null value is not allowed for field %q", field)
131+
}
132+
133+
func errMissing(props []string) error {
134+
return fmt.Errorf("missing required properties: %s", strings.Join(props, ", "))
135+
}
136+
137+
func errInvalidKind(typeName string, got json.Kind) error {
138+
return fmt.Errorf("invalid %s: got %v", typeName, got)
139+
}
140+
141+
func errInvalidValue(typeName string, data []byte) error {
142+
return fmt.Errorf("invalid %s: %s", typeName, data)
143+
}
144+
145+
func errLiteralMismatch(typeName string, expected string, got []byte) error {
146+
return fmt.Errorf("expected %s value %s, got %s", typeName, expected, got)
147+
}
148+
125149
func assertOnlyOne(message string, count int) {
126150
if count != 1 {
127151
panic(message)

0 commit comments

Comments
 (0)