1+ import { partition } from 'lodash-es' ;
12import SymbolFactory from '@/core/analyzer/symbol/factory' ;
23import { CompileError , CompileErrorCode } from '@/core/errors' ;
34import {
45 BlockExpressionNode ,
56 ElementDeclarationNode ,
67 FunctionApplicationNode ,
8+ FunctionExpressionNode ,
9+ IdentiferStreamNode ,
710 ListExpressionNode ,
811} from '@/core/parser/nodes' ;
912import { SyntaxToken } from '@/core/lexer/tokens' ;
1013import { ElementValidator } from '@/core/analyzer/validator/types' ;
1114import SymbolTable from '@/core/analyzer/symbol/symbolTable' ;
15+ import { extractVariableFromExpression } from '@/core/analyzer/utils' ;
16+
17+ const VALID_ARG_TYPES = new Set ( [
18+ 'integer' , 'bool' , 'bytea' , 'date' , 'double_precision' ,
19+ 'float4' , 'float8' , 'int2' , 'int4' , 'int8' , 'json' , 'jsonb' ,
20+ 'numeric' , 'text' , 'time' , 'timestamp' , 'timestamptz' , 'timetz' ,
21+ 'uuid' , 'varchar' , 'vector' ,
22+ ] ) ;
23+
24+ const VALID_RETURN_TYPES = new Set ( [
25+ 'void' , 'record' , 'trigger' ,
26+ ...VALID_ARG_TYPES ,
27+ ] ) ;
28+
29+ const VALID_LANGUAGES = new Set ( [ 'plpgsql' , 'sql' , 'c' , 'internal' ] ) ;
30+ const VALID_BEHAVIORS = new Set ( [ 'volatile' , 'immutable' , 'stable' ] ) ;
31+ const VALID_SECURITIES = new Set ( [ 'invoker' , 'definer' ] ) ;
32+
33+ const SINGLE_OCCURRENCE_FIELDS = new Set ( [ 'schema' , 'returns' , 'args' , 'body' , 'language' , 'behavior' , 'security' ] ) ;
1234
1335export default class FunctionValidator implements ElementValidator {
1436 private declarationNode : ElementDeclarationNode & { type : SyntaxToken } ;
@@ -32,7 +54,7 @@ export default class FunctionValidator implements ElementValidator {
3254
3355 private validateContext ( ) : CompileError [ ] {
3456 if ( this . declarationNode . parent instanceof ElementDeclarationNode ) {
35- return [ new CompileError ( CompileErrorCode . INVALID_POLICY_CONTEXT , 'A Function can only appear top-level' , this . declarationNode ) ] ;
57+ return [ new CompileError ( CompileErrorCode . INVALID_FUNCTION_CONTEXT , 'A Function can only appear top-level' , this . declarationNode ) ] ;
3658 }
3759
3860 return [ ] ;
@@ -62,6 +84,109 @@ export default class FunctionValidator implements ElementValidator {
6284 return [ new CompileError ( CompileErrorCode . UNEXPECTED_SIMPLE_BODY , 'A Function\'s body must be a block' , body ) ] ;
6385 }
6486
65- return [ ] ;
87+ const [ fields ] = partition ( body . body , ( e ) => e instanceof FunctionApplicationNode ) ;
88+ return this . validateFields ( fields as FunctionApplicationNode [ ] ) ;
89+ }
90+
91+ private validateFields ( fields : FunctionApplicationNode [ ] ) : CompileError [ ] {
92+ const seen = new Set < string > ( ) ;
93+ return fields . flatMap ( ( field ) => {
94+ if ( ! field . callee ) return [ ] ;
95+
96+ const fieldName = extractVariableFromExpression ( field . callee ) . unwrap_or ( '' ) . toLowerCase ( ) ;
97+ const errors : CompileError [ ] = [ ] ;
98+
99+ if ( SINGLE_OCCURRENCE_FIELDS . has ( fieldName ) ) {
100+ if ( seen . has ( fieldName ) ) {
101+ errors . push ( new CompileError ( CompileErrorCode . DUPLICATE_FUNCTION_FIELD , `'${ fieldName } ' can only appear once` , field ) ) ;
102+ }
103+ seen . add ( fieldName ) ;
104+ }
105+
106+ switch ( fieldName ) {
107+ case 'schema' :
108+ break ;
109+ case 'returns' : {
110+ const value = extractVariableFromExpression ( field . args [ 0 ] ) . unwrap_or ( '' ) ;
111+ if ( ! VALID_RETURN_TYPES . has ( value . toLowerCase ( ) ) ) {
112+ errors . push ( new CompileError (
113+ CompileErrorCode . INVALID_FUNCTION_FIELD_VALUE ,
114+ `'returns' must be a valid return type (e.g. void, integer, text, ...)` ,
115+ field . args [ 0 ] || field ,
116+ ) ) ;
117+ }
118+ break ;
119+ }
120+ case 'args' : {
121+ const arg = field . args [ 0 ] ;
122+ if ( arg instanceof ListExpressionNode ) {
123+ arg . elementList . forEach ( ( attr ) => {
124+ const argType = extractVariableFromExpression ( attr . value as any ) . unwrap_or ( '' ) ;
125+ if ( argType && ! VALID_ARG_TYPES . has ( argType . toLowerCase ( ) ) ) {
126+ errors . push ( new CompileError (
127+ CompileErrorCode . INVALID_FUNCTION_FIELD_VALUE ,
128+ `Argument type '${ argType } ' is not valid` ,
129+ attr . value || attr ,
130+ ) ) ;
131+ }
132+ } ) ;
133+ }
134+ break ;
135+ }
136+ case 'body' :
137+ if ( field . args [ 0 ] && ! ( field . args [ 0 ] instanceof FunctionExpressionNode ) ) {
138+ const value = extractVariableFromExpression ( field . args [ 0 ] ) . unwrap_or ( '' ) ;
139+ if ( ! value ) {
140+ errors . push ( new CompileError (
141+ CompileErrorCode . INVALID_FUNCTION_FIELD_VALUE ,
142+ '\'body\' must be an expression or a backtick string' ,
143+ field . args [ 0 ] ,
144+ ) ) ;
145+ }
146+ }
147+ break ;
148+ case 'language' : {
149+ const value = extractVariableFromExpression ( field . args [ 0 ] ) . unwrap_or ( '' ) ;
150+ if ( ! VALID_LANGUAGES . has ( value . toLowerCase ( ) ) ) {
151+ errors . push ( new CompileError (
152+ CompileErrorCode . INVALID_FUNCTION_FIELD_VALUE ,
153+ `'language' must be one of: ${ [ ...VALID_LANGUAGES ] . join ( ', ' ) } ` ,
154+ field . args [ 0 ] || field ,
155+ ) ) ;
156+ }
157+ break ;
158+ }
159+ case 'behavior' : {
160+ const value = extractVariableFromExpression ( field . args [ 0 ] ) . unwrap_or ( '' ) ;
161+ if ( ! VALID_BEHAVIORS . has ( value . toLowerCase ( ) ) ) {
162+ errors . push ( new CompileError (
163+ CompileErrorCode . INVALID_FUNCTION_FIELD_VALUE ,
164+ `'behavior' must be one of: ${ [ ...VALID_BEHAVIORS ] . join ( ', ' ) } ` ,
165+ field . args [ 0 ] || field ,
166+ ) ) ;
167+ }
168+ break ;
169+ }
170+ case 'security' : {
171+ const value = extractVariableFromExpression ( field . args [ 0 ] ) . unwrap_or ( '' ) ;
172+ if ( ! VALID_SECURITIES . has ( value . toLowerCase ( ) ) ) {
173+ errors . push ( new CompileError (
174+ CompileErrorCode . INVALID_FUNCTION_FIELD_VALUE ,
175+ `'security' must be one of: ${ [ ...VALID_SECURITIES ] . join ( ', ' ) } ` ,
176+ field . args [ 0 ] || field ,
177+ ) ) ;
178+ }
179+ break ;
180+ }
181+ default :
182+ errors . push ( new CompileError (
183+ CompileErrorCode . UNKNOWN_FUNCTION_FIELD ,
184+ `Unknown Function field '${ fieldName } '` ,
185+ field ,
186+ ) ) ;
187+ }
188+
189+ return errors ;
190+ } ) ;
66191 }
67192}
0 commit comments