Skip to content

Commit be2879d

Browse files
committed
add adjacently tagged enum support
1 parent 5373585 commit be2879d

7 files changed

Lines changed: 583 additions & 0 deletions

File tree

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import SwiftSyntax
2+
import SwiftSyntaxMacros
3+
4+
/// Peer macro — serves only as an annotation read by `TaggedCodableMacro`.
5+
public struct CodedAtMacro: PeerMacro {
6+
public static func expansion(
7+
of _: AttributeSyntax,
8+
providingPeersOf _: some DeclSyntaxProtocol,
9+
in _: some MacroExpansionContext
10+
) throws -> [DeclSyntax] { [] }
11+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import SwiftSyntax
2+
import SwiftSyntaxMacros
3+
4+
/// Peer macro — serves only as an annotation read by `TaggedCodableMacro`.
5+
public struct ContentAtMacro: PeerMacro {
6+
public static func expansion(
7+
of _: AttributeSyntax,
8+
providingPeersOf _: some DeclSyntaxProtocol,
9+
in _: some MacroExpansionContext
10+
) throws -> [DeclSyntax] { [] }
11+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import SwiftSyntax
2+
import SwiftSyntaxMacros
3+
4+
public struct TaggedCodableMacro: ExtensionMacro {
5+
private static let expectedConformances: Set<Conformance> = [.Decodable, .Encodable]
6+
7+
public static func expansion(
8+
of node: AttributeSyntax,
9+
attachedTo declaration: some DeclGroupSyntax,
10+
providingExtensionsOf type: some TypeSyntaxProtocol,
11+
conformingTo protocols: [TypeSyntax],
12+
in context: some MacroExpansionContext
13+
) throws -> [ExtensionDeclSyntax] {
14+
try withMacro(Self.self, in: context) {
15+
let conformancesToGenerate = Conformance.makeConformances(
16+
protocols: protocols,
17+
declaration: declaration,
18+
type: type,
19+
expectedConformances: expectedConformances
20+
)
21+
return try TaggedEnumMacroBase.expansion(
22+
of: node,
23+
attachedTo: declaration,
24+
providingExtensionsOf: type,
25+
conformancesToGenerate: conformancesToGenerate,
26+
expectedConformances: expectedConformances,
27+
in: context
28+
)
29+
}
30+
}
31+
}
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
import MacroToolkit
2+
import SwiftDiagnostics
3+
import SwiftSyntax
4+
import SwiftSyntaxMacros
5+
6+
enum TaggedEnumMacroBase {
7+
8+
// MARK: - Config
9+
10+
struct Config {
11+
let tagKey: String
12+
let tagIdentifier: String // camelCase Swift identifier for CodingKeys
13+
let paramsKey: String
14+
let paramsIdentifier: String // camelCase Swift identifier for CodingKeys
15+
let caseStyle: CaseStyleTransformer
16+
}
17+
18+
struct CaseStyleTransformer {
19+
let memberName: String
20+
21+
func transform(_ caseName: String) -> String {
22+
switch memberName {
23+
case "verbatim": return caseName
24+
case "camelCase": return caseName
25+
case "snakeCase": return toSnakeCase(caseName)
26+
case "screamingSnakeCase": return toSnakeCase(caseName).uppercased()
27+
default: return toSnakeCase(caseName).uppercased()
28+
}
29+
}
30+
31+
private func toSnakeCase(_ s: String) -> String {
32+
var result = ""
33+
for (i, c) in s.enumerated() {
34+
if c.isUppercase && i > 0 { result += "_" }
35+
result += String(c).lowercased()
36+
}
37+
return result
38+
}
39+
}
40+
41+
// MARK: - Entry Point
42+
43+
static func expansion(
44+
of node: AttributeSyntax,
45+
attachedTo declaration: some DeclGroupSyntax,
46+
providingExtensionsOf type: some TypeSyntaxProtocol,
47+
conformancesToGenerate: Set<Conformance>,
48+
expectedConformances: Set<Conformance>,
49+
in context: some MacroExpansionContext
50+
) throws -> [ExtensionDeclSyntax] {
51+
52+
let checker = ConformanceDiagnosticChecker(
53+
config: .init(replacementMacroName: [:])
54+
)
55+
try checker.verify(
56+
type: type,
57+
declaration: declaration,
58+
expectedConformances: expectedConformances,
59+
conformancesToGenerate: conformancesToGenerate
60+
)
61+
62+
guard let enumDecl = Enum(declaration) else {
63+
context.diagnose(
64+
SimpleDiagnosticMessage
65+
.error(
66+
message: "'@\(MacroConfiguration.current.name)' can only be applied to an enum",
67+
diagnosticID: MessageID(domain: MacroConfiguration.current.name, id: "requiresEnum")
68+
)
69+
.diagnose(at: declaration)
70+
)
71+
return []
72+
}
73+
74+
guard let config = parseConfig(from: declaration.attributes) else {
75+
context.diagnose(
76+
SimpleDiagnosticMessage
77+
.error(
78+
message: "'@\(MacroConfiguration.current.name)' requires both '@CodedAt' and '@ContentAt'",
79+
diagnosticID: MessageID(domain: MacroConfiguration.current.name, id: "missingAnnotations")
80+
)
81+
.diagnose(at: declaration)
82+
)
83+
return []
84+
}
85+
86+
let accessModifier: AccessModifier? = enumDecl.isPublic ? .public : nil
87+
88+
let rawCode = buildExtension(
89+
typeName: type.trimmedDescription,
90+
enumDecl: enumDecl,
91+
config: config,
92+
accessModifier: accessModifier,
93+
conformances: conformancesToGenerate
94+
)
95+
96+
guard !rawCode.isEmpty else { return [] }
97+
98+
let formatted: String
99+
do {
100+
formatted = try rawCode.swiftFormatted
101+
} catch {
102+
context.diagnose(
103+
CommonDiagnostic
104+
.internalError(message: "Internal Error = \(error). Couldn't format code")
105+
.diagnose(at: declaration)
106+
)
107+
return []
108+
}
109+
110+
guard let extensionDecl = DeclSyntax(stringLiteral: formatted).as(ExtensionDeclSyntax.self) else {
111+
context.diagnose(
112+
CommonDiagnostic
113+
.internalError(message: "Internal Error. Couldn't create extension from code = \(formatted)")
114+
.diagnose(at: declaration)
115+
)
116+
return []
117+
}
118+
119+
return [extensionDecl]
120+
}
121+
122+
// MARK: - Config Parsing
123+
124+
static func parseConfig(from attributes: AttributeListSyntax) -> Config? {
125+
guard
126+
let (tagKey, transformer) = parseCodedAt(from: attributes),
127+
let paramsKey = parseContentAt(from: attributes)
128+
else { return nil }
129+
130+
return Config(
131+
tagKey: tagKey,
132+
tagIdentifier: snakeToCamel(tagKey),
133+
paramsKey: paramsKey,
134+
paramsIdentifier: snakeToCamel(paramsKey),
135+
caseStyle: transformer
136+
)
137+
}
138+
139+
private static func parseCodedAt(
140+
from attributes: AttributeListSyntax
141+
) -> (String, CaseStyleTransformer)? {
142+
guard
143+
let attr = attributes
144+
.compactMap({ $0.as(AttributeSyntax.self) })
145+
.first(where: { $0.attributeName.trimmedDescription == "CodedAt" }),
146+
case let .argumentList(args) = attr.arguments,
147+
let firstArg = args.first,
148+
let strLit = firstArg.expression.as(StringLiteralExprSyntax.self),
149+
let segment = strLit.segments.first?.as(StringSegmentSyntax.self)
150+
else { return nil }
151+
152+
let tagKey = segment.content.text
153+
154+
var styleName = "screamingSnakeCase"
155+
if let styleArg = args.first(where: { $0.label?.text == "caseStyle" }),
156+
let member = styleArg.expression.as(MemberAccessExprSyntax.self) {
157+
styleName = member.declName.baseName.text
158+
}
159+
160+
return (tagKey, CaseStyleTransformer(memberName: styleName))
161+
}
162+
163+
private static func parseContentAt(from attributes: AttributeListSyntax) -> String? {
164+
guard
165+
let attr = attributes
166+
.compactMap({ $0.as(AttributeSyntax.self) })
167+
.first(where: { $0.attributeName.trimmedDescription == "ContentAt" }),
168+
case let .argumentList(args) = attr.arguments,
169+
let firstArg = args.first,
170+
let strLit = firstArg.expression.as(StringLiteralExprSyntax.self),
171+
let segment = strLit.segments.first?.as(StringSegmentSyntax.self)
172+
else { return nil }
173+
174+
return segment.content.text
175+
}
176+
177+
private static func snakeToCamel(_ key: String) -> String {
178+
let parts = key.split(separator: "_")
179+
guard let first = parts.first else { return key }
180+
return String(first) + parts.dropFirst().map { $0.prefix(1).uppercased() + $0.dropFirst() }.joined()
181+
}
182+
183+
// MARK: - Code Generation
184+
185+
private static func buildExtension(
186+
typeName: String,
187+
enumDecl: Enum,
188+
config: Config,
189+
accessModifier: AccessModifier?,
190+
conformances: Set<Conformance>
191+
) -> String {
192+
guard !conformances.isEmpty else { return "" }
193+
194+
let conformanceList = conformances.map(\.rawValue).sorted().joined(separator: ", ")
195+
let access = accessModifier.map { "\($0.rawValue) " } ?? ""
196+
197+
var lines: [String] = []
198+
lines.append("extension \(typeName): \(conformanceList) {")
199+
200+
// Per-case CodingKeys for cases with associated values
201+
for enumCase in enumDecl.cases {
202+
guard case let .associatedValue(params) = enumCase.value else { continue }
203+
let name = perCaseCodingKeysName(enumCase.identifier)
204+
lines.append("private enum \(name): String, CodingKey {")
205+
for param in params where param.label != nil {
206+
lines.append("case \(param.label!)")
207+
}
208+
lines.append("}")
209+
}
210+
211+
// Outer CodingKeys
212+
let tagId = config.tagIdentifier
213+
let tagRaw = tagId != config.tagKey ? " = \"\(config.tagKey)\"" : ""
214+
let paramsId = config.paramsIdentifier
215+
let paramsRaw = paramsId != config.paramsKey ? " = \"\(config.paramsKey)\"" : ""
216+
217+
lines.append("\(access)enum CodingKeys: String, CodingKey, CaseIterable, Sendable, Hashable {")
218+
lines.append("case \(tagId)\(tagRaw)")
219+
lines.append("case \(paramsId)\(paramsRaw)")
220+
lines.append("}")
221+
222+
// init(from:)
223+
if conformances.contains(.Decodable) {
224+
lines += buildDecoder(enumDecl: enumDecl, config: config, tagId: tagId, paramsId: paramsId, access: access)
225+
}
226+
227+
// encode(to:)
228+
if conformances.contains(.Encodable) {
229+
lines += buildEncoder(enumDecl: enumDecl, config: config, tagId: tagId, paramsId: paramsId, access: access)
230+
}
231+
232+
lines.append("}")
233+
return lines.joined(separator: "\n")
234+
}
235+
236+
private static func buildDecoder(
237+
enumDecl: Enum,
238+
config: Config,
239+
tagId: String,
240+
paramsId: String,
241+
access: String
242+
) -> [String] {
243+
var lines: [String] = []
244+
lines.append("\(access)init(from decoder: Decoder) throws {")
245+
lines.append("let container = try decoder.container(keyedBy: CodingKeys.self)")
246+
lines.append("let tag = try container.decode(String.self, forKey: .\(tagId))")
247+
lines.append("switch tag {")
248+
249+
for enumCase in enumDecl.cases {
250+
let tagValue = config.caseStyle.transform(enumCase.identifier)
251+
lines.append("case \"\(tagValue)\":")
252+
253+
switch enumCase.value {
254+
case nil:
255+
lines.append("self = .\(enumCase.identifier)")
256+
257+
case .associatedValue(let params):
258+
let keysType = perCaseCodingKeysName(enumCase.identifier)
259+
lines.append("let params = try container.nestedContainer(keyedBy: \(keysType).self, forKey: .\(paramsId))")
260+
let argList = params.compactMap { p -> String? in
261+
guard let label = p.label else { return nil }
262+
let typeName = p.type.typeDescription(preservingOptional: false)
263+
let fn = p.type.isOptional ? "decodeIfPresent" : "decode"
264+
return "\(label): try params.\(fn)(\(typeName), forKey: .\(label))"
265+
}.joined(separator: ", ")
266+
lines.append("self = .\(enumCase.identifier)(\(argList))")
267+
268+
default:
269+
break
270+
}
271+
}
272+
273+
lines.append("default:")
274+
lines.append(
275+
"throw DecodingError.dataCorrupted(.init(codingPath: container.codingPath, debugDescription: \"Unknown \\(tag)\"))"
276+
)
277+
lines.append("}")
278+
lines.append("}")
279+
return lines
280+
}
281+
282+
private static func buildEncoder(
283+
enumDecl: Enum,
284+
config: Config,
285+
tagId: String,
286+
paramsId: String,
287+
access: String
288+
) -> [String] {
289+
var lines: [String] = []
290+
lines.append("\(access)func encode(to encoder: Encoder) throws {")
291+
lines.append("var container = encoder.container(keyedBy: CodingKeys.self)")
292+
lines.append("switch self {")
293+
294+
for enumCase in enumDecl.cases {
295+
let tagValue = config.caseStyle.transform(enumCase.identifier)
296+
297+
switch enumCase.value {
298+
case nil:
299+
lines.append("case .\(enumCase.identifier):")
300+
lines.append("try container.encode(\"\(tagValue)\", forKey: .\(tagId))")
301+
302+
case .associatedValue(let params):
303+
let keysType = perCaseCodingKeysName(enumCase.identifier)
304+
let bindings = params.compactMap(\.label).joined(separator: ", ")
305+
lines.append("case let .\(enumCase.identifier)(\(bindings)):")
306+
lines.append("try container.encode(\"\(tagValue)\", forKey: .\(tagId))")
307+
lines.append("var params = container.nestedContainer(keyedBy: \(keysType).self, forKey: .\(paramsId))")
308+
for param in params {
309+
guard let label = param.label else { continue }
310+
let fn = param.type.isOptional ? "encodeIfPresent" : "encode"
311+
lines.append("try params.\(fn)(\(label), forKey: .\(label))")
312+
}
313+
314+
default:
315+
break
316+
}
317+
}
318+
319+
lines.append("}")
320+
lines.append("}")
321+
return lines
322+
}
323+
324+
private static func perCaseCodingKeysName(_ caseName: String) -> String {
325+
caseName.prefix(1).uppercased() + caseName.dropFirst() + "CodingKeys"
326+
}
327+
}

Sources/Macro/Plugin.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
EncodableMacro.self,
2828
DecodableMacro.self,
2929

30+
// Tagged enum
31+
TaggedCodableMacro.self,
32+
CodedAtMacro.self,
33+
ContentAtMacro.self,
34+
3035
// Coding customization
3136
CodingKeyMacro.self,
3237
OmitCodingMacro.self,

0 commit comments

Comments
 (0)