|
| 1 | +import { |
| 2 | + type Mark as ProseMirrorMark, |
| 3 | + type MarkSpec, |
| 4 | + type Node as ProseMirrorNode, |
| 5 | + type NodeSpec, |
| 6 | + type ParseRule, |
| 7 | + Schema, |
| 8 | +} from "@mxm-editor/pm"; |
| 9 | +import type { Editor } from "./Editor"; |
| 10 | +import type { |
| 11 | + AnyExtension, |
| 12 | + ExtensionAttribute, |
| 13 | + GlobalAttributes, |
| 14 | + MarkConfig, |
| 15 | + NodeConfig, |
| 16 | +} from "./types"; |
| 17 | +import { cleanObject, mergeAttributes } from "./utils"; |
| 18 | + |
| 19 | +export type Extensions = AnyExtension[]; |
| 20 | + |
| 21 | +export interface ResolvedExtensionAttribute { |
| 22 | + type: string; |
| 23 | + name: string; |
| 24 | + attribute: ExtensionAttribute; |
| 25 | +} |
| 26 | + |
| 27 | +type ResolvedNodeExtension = AnyExtension & { |
| 28 | + type: "node"; |
| 29 | + config: NodeConfig<any, any>; |
| 30 | +}; |
| 31 | + |
| 32 | +type ResolvedMarkExtension = AnyExtension & { |
| 33 | + type: "mark"; |
| 34 | + config: MarkConfig<any, any>; |
| 35 | +}; |
| 36 | + |
| 37 | +const staticEditor = {} as Editor; |
| 38 | + |
| 39 | +function createStaticContext(extension: AnyExtension) { |
| 40 | + return extension.createContext(staticEditor); |
| 41 | +} |
| 42 | + |
| 43 | +function getAttributesForResolvedExtension( |
| 44 | + extension: AnyExtension, |
| 45 | + extensions: Extensions, |
| 46 | +) { |
| 47 | + const context = createStaticContext(extension); |
| 48 | + const globalAttributes = getGlobalAttributesForResolvedExtension( |
| 49 | + extension, |
| 50 | + extensions, |
| 51 | + ); |
| 52 | + |
| 53 | + return { |
| 54 | + ...globalAttributes, |
| 55 | + ...(extension.config.addAttributes?.call(context) ?? {}), |
| 56 | + } as Record<string, ExtensionAttribute>; |
| 57 | +} |
| 58 | + |
| 59 | +function getGlobalAttributesForResolvedExtension( |
| 60 | + extension: AnyExtension, |
| 61 | + extensions: Extensions, |
| 62 | +) { |
| 63 | + return extensions.reduce<Record<string, ExtensionAttribute>>( |
| 64 | + (attributes, item) => { |
| 65 | + const context = createStaticContext(item); |
| 66 | + const globalAttributes = item.config.addGlobalAttributes?.call(context) ?? []; |
| 67 | + |
| 68 | + globalAttributes.forEach((globalAttribute: GlobalAttributes) => { |
| 69 | + if (!globalAttribute.types.includes(extension.name)) { |
| 70 | + return; |
| 71 | + } |
| 72 | + |
| 73 | + Object.assign(attributes, globalAttribute.attributes); |
| 74 | + }); |
| 75 | + |
| 76 | + return attributes; |
| 77 | + }, |
| 78 | + {}, |
| 79 | + ); |
| 80 | +} |
| 81 | + |
| 82 | +function createAttributesSpec( |
| 83 | + attributes: Record<string, ExtensionAttribute>, |
| 84 | +): Record<string, { default?: any }> { |
| 85 | + return Object.fromEntries( |
| 86 | + Object.entries(attributes).map(([name, attribute]) => { |
| 87 | + const spec: { default?: any } = {}; |
| 88 | + |
| 89 | + if ("default" in attribute) { |
| 90 | + spec.default = attribute.default; |
| 91 | + } |
| 92 | + |
| 93 | + return [name, spec]; |
| 94 | + }), |
| 95 | + ); |
| 96 | +} |
| 97 | + |
| 98 | +function injectParseAttributes<T extends ParseRule>( |
| 99 | + rules: readonly T[] | undefined, |
| 100 | + attributes: Record<string, ExtensionAttribute>, |
| 101 | +) { |
| 102 | + if (!rules?.length) { |
| 103 | + return rules; |
| 104 | + } |
| 105 | + |
| 106 | + return rules.map((rule) => { |
| 107 | + const originalGetAttrs = rule.getAttrs; |
| 108 | + const staticAttrs = |
| 109 | + "attrs" in rule && rule.attrs && typeof rule.attrs === "object" |
| 110 | + ? rule.attrs |
| 111 | + : null; |
| 112 | + |
| 113 | + return { |
| 114 | + ...rule, |
| 115 | + getAttrs: (node: string | Node) => { |
| 116 | + const derivedAttrs = |
| 117 | + typeof originalGetAttrs === "function" |
| 118 | + ? (originalGetAttrs as (value: unknown) => Record<string, any> | false | null)(node) |
| 119 | + : null; |
| 120 | + |
| 121 | + const baseAttrs = { |
| 122 | + ...(staticAttrs ?? {}), |
| 123 | + ...(derivedAttrs && typeof derivedAttrs === "object" ? derivedAttrs : {}), |
| 124 | + }; |
| 125 | + |
| 126 | + if (derivedAttrs === false) { |
| 127 | + return false; |
| 128 | + } |
| 129 | + |
| 130 | + if ( |
| 131 | + typeof HTMLElement === "undefined" |
| 132 | + || !(node instanceof HTMLElement) |
| 133 | + ) { |
| 134 | + return baseAttrs; |
| 135 | + } |
| 136 | + |
| 137 | + return { |
| 138 | + ...baseAttrs, |
| 139 | + ...Object.fromEntries( |
| 140 | + Object.entries(attributes).map(([name, attribute]) => [ |
| 141 | + name, |
| 142 | + attribute.parseHTML |
| 143 | + ? attribute.parseHTML(node) |
| 144 | + : baseAttrs[name] |
| 145 | + ?? node.getAttribute(name) |
| 146 | + ?? attribute.default, |
| 147 | + ]), |
| 148 | + ), |
| 149 | + }; |
| 150 | + }, |
| 151 | + }; |
| 152 | + }) as T[]; |
| 153 | +} |
| 154 | + |
| 155 | +export function resolveExtensions(extensions: Extensions): Extensions { |
| 156 | + const resolved: AnyExtension[] = []; |
| 157 | + |
| 158 | + const visit = (items: Extensions) => { |
| 159 | + items.forEach((extension) => { |
| 160 | + resolved.push(extension); |
| 161 | + |
| 162 | + const nested = extension.config.addExtensions?.call( |
| 163 | + createStaticContext(extension), |
| 164 | + ); |
| 165 | + |
| 166 | + if (nested?.length) { |
| 167 | + visit(nested); |
| 168 | + } |
| 169 | + }); |
| 170 | + }; |
| 171 | + |
| 172 | + visit(extensions); |
| 173 | + |
| 174 | + return resolved.sort((a, b) => b.priority - a.priority); |
| 175 | +} |
| 176 | + |
| 177 | +export function splitExtensions(extensions: Extensions) { |
| 178 | + const resolved = resolveExtensions(extensions); |
| 179 | + |
| 180 | + return { |
| 181 | + nodeExtensions: resolved.filter( |
| 182 | + (extension): extension is ResolvedNodeExtension => extension.type === "node", |
| 183 | + ), |
| 184 | + markExtensions: resolved.filter( |
| 185 | + (extension): extension is ResolvedMarkExtension => extension.type === "mark", |
| 186 | + ), |
| 187 | + }; |
| 188 | +} |
| 189 | + |
| 190 | +export function getAttributesForExtension( |
| 191 | + extension: AnyExtension, |
| 192 | + extensions: Extensions, |
| 193 | +) { |
| 194 | + const resolved = resolveExtensions(extensions); |
| 195 | + |
| 196 | + return getAttributesForResolvedExtension(extension, resolved); |
| 197 | +} |
| 198 | + |
| 199 | +export function getAttributesForExtensionFromResolvedExtensions( |
| 200 | + extension: AnyExtension, |
| 201 | + extensions: Extensions, |
| 202 | +) { |
| 203 | + return getAttributesForResolvedExtension(extension, extensions); |
| 204 | +} |
| 205 | + |
| 206 | +export function getAttributesFromResolvedExtensions( |
| 207 | + extensions: Extensions, |
| 208 | +): ResolvedExtensionAttribute[] { |
| 209 | + return extensions.flatMap((extension) => |
| 210 | + Object.entries(getAttributesForResolvedExtension(extension, extensions)).map( |
| 211 | + ([name, attribute]) => ({ |
| 212 | + type: extension.name, |
| 213 | + name, |
| 214 | + attribute, |
| 215 | + }), |
| 216 | + ), |
| 217 | + ); |
| 218 | +} |
| 219 | + |
| 220 | +export function getAttributesFromExtensions( |
| 221 | + extensions: Extensions, |
| 222 | +): ResolvedExtensionAttribute[] { |
| 223 | + return getAttributesFromResolvedExtensions(resolveExtensions(extensions)); |
| 224 | +} |
| 225 | + |
| 226 | +export function getRenderedAttributes( |
| 227 | + attrs: Record<string, any>, |
| 228 | + attributes: Record<string, ExtensionAttribute>, |
| 229 | +) { |
| 230 | + return Object.entries(attributes).reduce<Record<string, string>>( |
| 231 | + (rendered, [name, attribute]) => { |
| 232 | + const value = attrs[name]; |
| 233 | + |
| 234 | + if (attribute.renderHTML) { |
| 235 | + return mergeAttributes(rendered, attribute.renderHTML(attrs)); |
| 236 | + } |
| 237 | + |
| 238 | + if (value === undefined || value === null) { |
| 239 | + return rendered; |
| 240 | + } |
| 241 | + |
| 242 | + return mergeAttributes(rendered, { |
| 243 | + [name]: String(value), |
| 244 | + }); |
| 245 | + }, |
| 246 | + {}, |
| 247 | + ); |
| 248 | +} |
| 249 | + |
| 250 | +export function getSchemaByResolvedExtensions(extensions: Extensions) { |
| 251 | + const nodes = extensions.filter( |
| 252 | + (extension): extension is ResolvedNodeExtension => extension.type === "node", |
| 253 | + ); |
| 254 | + const marks = extensions.filter( |
| 255 | + (extension): extension is ResolvedMarkExtension => extension.type === "mark", |
| 256 | + ); |
| 257 | + const topNode = nodes.find((node) => node.config.topNode)?.name; |
| 258 | + |
| 259 | + return new Schema({ |
| 260 | + topNode, |
| 261 | + nodes: Object.fromEntries( |
| 262 | + nodes.map((node) => { |
| 263 | + const context = createStaticContext(node); |
| 264 | + const attributes = getAttributesForResolvedExtension(node, extensions); |
| 265 | + const group = |
| 266 | + typeof node.config.group === "function" |
| 267 | + ? node.config.group.call(context) |
| 268 | + : node.config.group; |
| 269 | + const inline = |
| 270 | + typeof node.config.inline === "function" |
| 271 | + ? node.config.inline.call(context) |
| 272 | + : node.config.inline; |
| 273 | + const spec: NodeSpec = cleanObject({ |
| 274 | + content: node.config.content, |
| 275 | + marks: node.config.marks, |
| 276 | + group, |
| 277 | + inline, |
| 278 | + atom: node.config.atom, |
| 279 | + selectable: node.config.selectable, |
| 280 | + draggable: node.config.draggable, |
| 281 | + code: node.config.code, |
| 282 | + defining: node.config.defining, |
| 283 | + isolating: node.config.isolating, |
| 284 | + attrs: createAttributesSpec(attributes), |
| 285 | + ...(node.config.extendNodeSchema ?? {}), |
| 286 | + }); |
| 287 | + |
| 288 | + if (node.config.parseHTML) { |
| 289 | + spec.parseDOM = injectParseAttributes( |
| 290 | + node.config.parseHTML.call(context), |
| 291 | + attributes, |
| 292 | + ); |
| 293 | + } |
| 294 | + |
| 295 | + if (node.config.renderHTML) { |
| 296 | + spec.toDOM = (pmNode: ProseMirrorNode) => |
| 297 | + node.config.renderHTML!.call(context, { |
| 298 | + node: pmNode, |
| 299 | + HTMLAttributes: getRenderedAttributes(pmNode.attrs, attributes), |
| 300 | + }); |
| 301 | + } |
| 302 | + |
| 303 | + return [node.name, spec]; |
| 304 | + }), |
| 305 | + ), |
| 306 | + marks: Object.fromEntries( |
| 307 | + marks.map((mark) => { |
| 308 | + const context = createStaticContext(mark); |
| 309 | + const attributes = getAttributesForResolvedExtension(mark, extensions); |
| 310 | + const inclusive = |
| 311 | + typeof mark.config.inclusive === "function" |
| 312 | + ? mark.config.inclusive.call(context) |
| 313 | + : mark.config.inclusive; |
| 314 | + const spec: MarkSpec = cleanObject({ |
| 315 | + inclusive, |
| 316 | + excludes: mark.config.excludes, |
| 317 | + group: mark.config.group, |
| 318 | + code: mark.config.code, |
| 319 | + attrs: createAttributesSpec(attributes), |
| 320 | + }); |
| 321 | + |
| 322 | + if (mark.config.parseHTML) { |
| 323 | + spec.parseDOM = injectParseAttributes( |
| 324 | + mark.config.parseHTML.call(context), |
| 325 | + attributes, |
| 326 | + ); |
| 327 | + } |
| 328 | + |
| 329 | + if (mark.config.renderHTML) { |
| 330 | + spec.toDOM = (pmMark: ProseMirrorMark) => |
| 331 | + mark.config.renderHTML!.call(context, { |
| 332 | + mark: pmMark, |
| 333 | + HTMLAttributes: getRenderedAttributes(pmMark.attrs, attributes), |
| 334 | + }); |
| 335 | + } |
| 336 | + |
| 337 | + return [mark.name, spec]; |
| 338 | + }), |
| 339 | + ), |
| 340 | + }); |
| 341 | +} |
| 342 | + |
| 343 | +export function getSchema(extensions: Extensions) { |
| 344 | + return getSchemaByResolvedExtensions(resolveExtensions(extensions)); |
| 345 | +} |
0 commit comments