diff --git a/Cargo.lock b/Cargo.lock index 03a73b426..0947a1d29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5103,11 +5103,13 @@ dependencies = [ "protobuf", "pyo3", "pyo3-build-config", + "regex", "serde_json", "strum", "strum_macros", "yara-x", "yara-x-fmt", + "yara-x-parser", "yara-x-proto-json", ] diff --git a/lib/src/compiler/ir/ast2ir.rs b/lib/src/compiler/ir/ast2ir.rs index e7f2a045b..2047118e8 100644 --- a/lib/src/compiler/ir/ast2ir.rs +++ b/lib/src/compiler/ir/ast2ir.rs @@ -1815,9 +1815,9 @@ fn func_call_from_ast( let expected_arg_types: Vec = if signature.method_of().is_some() { - signature.args.iter().skip(1).map(|arg| arg.ty()).collect() + signature.args.iter().skip(1).map(|(_, arg)| arg.ty()).collect() } else { - signature.args.iter().map(|arg| arg.ty()).collect() + signature.args.iter().map(|(_, arg)| arg.ty()).collect() }; if arg_types == expected_arg_types { diff --git a/lib/src/modules/mod.rs b/lib/src/modules/mod.rs index 68a6557a0..a472a49cd 100644 --- a/lib/src/modules/mod.rs +++ b/lib/src/modules/mod.rs @@ -384,7 +384,11 @@ pub mod mods { for signature in func.signatures() { signatures.push(FuncSignature { - args: signature.args.iter().map(Type::from).collect(), + args: signature + .args + .iter() + .map(|(name, ty)| (name.clone(), Type::from(ty))) + .collect(), ret: Type::from(&signature.result), description: signature.description.clone(), }); @@ -397,8 +401,8 @@ pub mod mods { /// Describes a function signature. #[derive(Clone, Debug, PartialEq)] pub struct FuncSignature { - /// The types of the function arguments. - pub args: Vec, + /// The names and types of the function arguments. + pub args: Vec<(String, Type)>, /// The return type for the function. pub ret: Type, /// Function's documentation description. 
diff --git a/lib/src/types/func.rs b/lib/src/types/func.rs index 0811e63f3..baa0c5cab 100644 --- a/lib/src/types/func.rs +++ b/lib/src/types/func.rs @@ -238,8 +238,7 @@ where #[derive(Clone, Serialize, Deserialize, Debug)] pub(crate) struct FuncSignature { pub mangled_name: MangledFnName, - pub args: Vec, - pub arg_names: Vec, + pub args: Vec<(String, TypeValue)>, pub result: TypeValue, pub description: Option>, } @@ -291,13 +290,11 @@ impl> From for FuncSignature { let (args_with_names, result) = mangled_name.unmangle(); let mut args = Vec::with_capacity(args_with_names.len()); - let mut arg_names = Vec::with_capacity(args_with_names.len()); for (name, ty) in args_with_names { - args.push(ty); - arg_names.push(name.to_string()); + args.push((name.to_string(), ty)); } - Self { mangled_name, args, arg_names, result, description: None } + Self { mangled_name, args, result, description: None } } } diff --git a/lib/src/wasm/mod.rs b/lib/src/wasm/mod.rs index 25e68d1fb..91a1a4657 100644 --- a/lib/src/wasm/mod.rs +++ b/lib/src/wasm/mod.rs @@ -236,14 +236,14 @@ impl WasmExport { function.add_signature(signature); } else { let mut func = Func::from(mangled_name); - // Update the description for the first and only signature in the function. + // Update the description for the first and only signature in + // the function. let signature = func.signatures_mut().get_mut(0).unwrap(); - // It's safe to get a mutable reference to the signature with Rc::get_mut - // because the Rc was just crated and there's a single reference to it. + // It's safe to get a mutable reference to the signature with + // Rc::get_mut because the Rc was just crated and there's a + // single reference to it. 
let signature = Rc::get_mut(signature).unwrap(); - signature.description = export.description.clone(); - functions.insert(export.name, func); } } diff --git a/ls/src/features/completion.rs b/ls/src/features/completion.rs index aa1dab19a..e82e01afd 100644 --- a/ls/src/features/completion.rs +++ b/ls/src/features/completion.rs @@ -398,25 +398,26 @@ fn field_suggestions(token: &Token) -> Option> { .signatures .iter() .map(|sig| { - let arg_types = sig + let args = sig .args .iter() - .map(ty_to_string) + .map(|(name, ty)| format!("{}: {}", name, ty_to_string(ty))) .collect::>(); - let args_template = arg_types + let args_template = sig + .args .iter() .enumerate() - .map(|(n, arg_type)| { - format!("${{{}:{arg_type}}}", n + 1) + .map(|(n, (name, _))| { + format!("${{{}:{name}}}", n + 1) }) - .join(","); + .join(", "); CompletionItem { label: format!( "{}({})", name, - arg_types.join(", ") + args.join(", ") ), kind: Some(CompletionItemKind::METHOD), insert_text: Some(format!( @@ -439,7 +440,7 @@ fn field_suggestions(token: &Token) -> Option> { name, sig.args .iter() - .map(ty_to_string) + .map(|(name, ty)| format!("{}: {}", name, ty_to_string(ty))) .join(", "), ty_to_string(&sig.ret), docs diff --git a/ls/src/features/hover.rs b/ls/src/features/hover.rs index 241f37159..2291b5a85 100644 --- a/ls/src/features/hover.rs +++ b/ls/src/features/hover.rs @@ -120,7 +120,11 @@ pub fn hover( signature .args .iter() - .map(ty_to_string) + .map(|(name, ty)| format!( + "{}: {}", + name, + ty_to_string(ty) + )) .join(", "), ty_to_string(&signature.ret), doc diff --git a/ls/src/features/signature_help.rs b/ls/src/features/signature_help.rs index d357bbd86..84ff37f4a 100644 --- a/ls/src/features/signature_help.rs +++ b/ls/src/features/signature_help.rs @@ -72,11 +72,12 @@ pub fn signature_help( let mut param_iterator = signature.args.iter(); let mut param_info = Vec::new(); - // Traverse all parameters and insert `, ` to the label, - // if the parameters is not last. 
- if let Some(mut curr_type) = param_iterator.next() { + if let Some((name, ty)) = param_iterator.next() { + let mut curr_name = name; + let mut curr_type = ty; loop { - let ty_str = ty_to_string(curr_type); + let ty_str = + format!("{}: {}", curr_name, ty_to_string(curr_type)); param_info.push(ParameterInformation { label: ParameterLabel::LabelOffsets([ curr_signature.len() as u32, @@ -85,8 +86,9 @@ pub fn signature_help( documentation: None, }); curr_signature.push_str(&ty_str); - if let Some(next_type) = param_iterator.next() { + if let Some((next_name, next_type)) = param_iterator.next() { curr_signature.push_str(", "); + curr_name = next_name; curr_type = next_type; } else { break; diff --git a/ls/src/tests/testdata/completion8.response.json b/ls/src/tests/testdata/completion8.response.json index bc06569f2..3559edcbb 100644 --- a/ls/src/tests/testdata/completion8.response.json +++ b/ls/src/tests/testdata/completion8.response.json @@ -1806,12 +1806,12 @@ { "documentation": { "kind": "markdown", - "value": "## `delayed_import_rva(string, string) -> integer`\n\n Returns the RVA of a delayed import where the DLL name matches\n `dll_name` and the function name matches `func_name`.\n\n Both `dll_name` and `func_name` are case-insensitive." + "value": "## `delayed_import_rva(dll_name: string, func_name: string) -> integer`\n\n Returns the RVA of a delayed import where the DLL name matches\n `dll_name` and the function name matches `func_name`.\n\n Both `dll_name` and `func_name` are case-insensitive." 
}, - "insertText": "delayed_import_rva(${1:string},${2:string})", + "insertText": "delayed_import_rva(${1:dll_name}, ${2:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "delayed_import_rva(string, string)", + "label": "delayed_import_rva(dll_name: string, func_name: string)", "labelDetails": { "description": "func()" } @@ -1819,12 +1819,12 @@ { "documentation": { "kind": "markdown", - "value": "## `delayed_import_rva(string, integer) -> integer`\n\n Returns the RVA of an import where the DLL name matches\n `dll_name` and the ordinal number is `ordinal`.\n\n `dll_name` is case-insensitive." + "value": "## `delayed_import_rva(dll_name: string, ordinal: integer) -> integer`\n\n Returns the RVA of an import where the DLL name matches\n `dll_name` and the ordinal number is `ordinal`.\n\n `dll_name` is case-insensitive." }, - "insertText": "delayed_import_rva(${1:string},${2:integer})", + "insertText": "delayed_import_rva(${1:dll_name}, ${2:ordinal})", "insertTextFormat": 2, "kind": 2, - "label": "delayed_import_rva(string, integer)", + "label": "delayed_import_rva(dll_name: string, ordinal: integer)", "labelDetails": { "description": "func()" } @@ -1832,12 +1832,12 @@ { "documentation": { "kind": "markdown", - "value": "## `exports(regexp) -> bool`\n\n Returns true if the PE file exports a function with a name that matches\n the given regular expression." + "value": "## `exports(func_name: regexp) -> bool`\n\n Returns true if the PE file exports a function with a name that matches\n the given regular expression." }, - "insertText": "exports(${1:regexp})", + "insertText": "exports(${1:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "exports(regexp)", + "label": "exports(func_name: regexp)", "labelDetails": { "description": "func()" } @@ -1845,12 +1845,12 @@ { "documentation": { "kind": "markdown", - "value": "## `exports(string) -> bool`\n\n Returns true if the PE file exports a function with the given name." 
+ "value": "## `exports(func_name: string) -> bool`\n\n Returns true if the PE file exports a function with the given name." }, - "insertText": "exports(${1:string})", + "insertText": "exports(${1:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "exports(string)", + "label": "exports(func_name: string)", "labelDetails": { "description": "func()" } @@ -1858,12 +1858,12 @@ { "documentation": { "kind": "markdown", - "value": "## `exports(integer) -> bool`\n\n Returns true if the PE file exports a function with the given ordinal." + "value": "## `exports(ordinal: integer) -> bool`\n\n Returns true if the PE file exports a function with the given ordinal." }, - "insertText": "exports(${1:integer})", + "insertText": "exports(${1:ordinal})", "insertTextFormat": 2, "kind": 2, - "label": "exports(integer)", + "label": "exports(ordinal: integer)", "labelDetails": { "description": "func()" } @@ -1871,12 +1871,12 @@ { "documentation": { "kind": "markdown", - "value": "## `exports_index(regexp) -> integer`\n\n Returns true if the PE file exports a function with a name that matches\n the given regular expression." + "value": "## `exports_index(func_name: regexp) -> integer`\n\n Returns true if the PE file exports a function with a name that matches\n the given regular expression." }, - "insertText": "exports_index(${1:regexp})", + "insertText": "exports_index(${1:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "exports_index(regexp)", + "label": "exports_index(func_name: regexp)", "labelDetails": { "description": "func()" } @@ -1884,12 +1884,12 @@ { "documentation": { "kind": "markdown", - "value": "## `exports_index(string) -> integer`\n\n Returns true if the PE file exports a function with the given name." + "value": "## `exports_index(func_name: string) -> integer`\n\n Returns true if the PE file exports a function with the given name." 
}, - "insertText": "exports_index(${1:string})", + "insertText": "exports_index(${1:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "exports_index(string)", + "label": "exports_index(func_name: string)", "labelDetails": { "description": "func()" } @@ -1897,12 +1897,12 @@ { "documentation": { "kind": "markdown", - "value": "## `exports_index(integer) -> integer`\n\n Returns true if the PE file exports a function with the given ordinal." + "value": "## `exports_index(ordinal: integer) -> integer`\n\n Returns true if the PE file exports a function with the given ordinal." }, - "insertText": "exports_index(${1:integer})", + "insertText": "exports_index(${1:ordinal})", "insertTextFormat": 2, "kind": 2, - "label": "exports_index(integer)", + "label": "exports_index(ordinal: integer)", "labelDetails": { "description": "func()" } @@ -1923,12 +1923,12 @@ { "documentation": { "kind": "markdown", - "value": "## `import_rva(string, string) -> integer`\n\n Returns the RVA of an import where the DLL name matches\n `dll_name` and the function name matches `func_name`.\n\n Both `dll_name` and `func_name` are case-insensitive." + "value": "## `import_rva(dll_name: string, func_name: string) -> integer`\n\n Returns the RVA of an import where the DLL name matches\n `dll_name` and the function name matches `func_name`.\n\n Both `dll_name` and `func_name` are case-insensitive." }, - "insertText": "import_rva(${1:string},${2:string})", + "insertText": "import_rva(${1:dll_name}, ${2:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "import_rva(string, string)", + "label": "import_rva(dll_name: string, func_name: string)", "labelDetails": { "description": "func()" } @@ -1936,12 +1936,12 @@ { "documentation": { "kind": "markdown", - "value": "## `import_rva(string, integer) -> integer`\n\n Returns the RVA of an import where the DLL name matches\n `dll_name` and the ordinal number is `ordinal`.\n\n `dll_name` is case-insensitive." 
+ "value": "## `import_rva(dll_name: string, ordinal: integer) -> integer`\n\n Returns the RVA of an import where the DLL name matches\n `dll_name` and the ordinal number is `ordinal`.\n\n `dll_name` is case-insensitive." }, - "insertText": "import_rva(${1:string},${2:integer})", + "insertText": "import_rva(${1:dll_name}, ${2:ordinal})", "insertTextFormat": 2, "kind": 2, - "label": "import_rva(string, integer)", + "label": "import_rva(dll_name: string, ordinal: integer)", "labelDetails": { "description": "func()" } @@ -1949,12 +1949,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(regexp, regexp) -> integer`\n\n Returns the number of imported functions where the function's name matches\n `func_name` and the DLL name matches `dll_name`.\n\n Both `dll_name` and `func_name` are case-sensitive unless you use the \"/i\"\n modifier in the regexp, as shown in the example below." + "value": "## `imports(dll_name: regexp, func_name: regexp) -> integer`\n\n Returns the number of imported functions where the function's name matches\n `func_name` and the DLL name matches `dll_name`.\n\n Both `dll_name` and `func_name` are case-sensitive unless you use the \"/i\"\n modifier in the regexp, as shown in the example below." }, - "insertText": "imports(${1:regexp},${2:regexp})", + "insertText": "imports(${1:dll_name}, ${2:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "imports(regexp, regexp)", + "label": "imports(dll_name: regexp, func_name: regexp)", "labelDetails": { "description": "func()" } @@ -1962,12 +1962,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(string, string) -> bool`\n\n Returns true if the PE imports `func_name` from `dll_name`.\n\n Both `func_name` and `dll_name` are case-insensitive." + "value": "## `imports(dll_name: string, func_name: string) -> bool`\n\n Returns true if the PE imports `func_name` from `dll_name`.\n\n Both `func_name` and `dll_name` are case-insensitive." 
}, - "insertText": "imports(${1:string},${2:string})", + "insertText": "imports(${1:dll_name}, ${2:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "imports(string, string)", + "label": "imports(dll_name: string, func_name: string)", "labelDetails": { "description": "func()" } @@ -1975,12 +1975,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(string, integer) -> integer`\n\n Returns true if the PE imports `ordinal` from `dll_name`.\n\n `dll_name` is case-insensitive." + "value": "## `imports(dll_name: string, ordinal: integer) -> integer`\n\n Returns true if the PE imports `ordinal` from `dll_name`.\n\n `dll_name` is case-insensitive." }, - "insertText": "imports(${1:string},${2:integer})", + "insertText": "imports(${1:dll_name}, ${2:ordinal})", "insertTextFormat": 2, "kind": 2, - "label": "imports(string, integer)", + "label": "imports(dll_name: string, ordinal: integer)", "labelDetails": { "description": "func()" } @@ -1988,12 +1988,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(string) -> integer`\n\n Returns the number of functions imported by the PE from `dll_name`.\n\n `dll_name` is case-insensitive." + "value": "## `imports(dll_name: string) -> integer`\n\n Returns the number of functions imported by the PE from `dll_name`.\n\n `dll_name` is case-insensitive." }, - "insertText": "imports(${1:string})", + "insertText": "imports(${1:dll_name})", "insertTextFormat": 2, "kind": 2, - "label": "imports(string)", + "label": "imports(dll_name: string)", "labelDetails": { "description": "func()" } @@ -2001,12 +2001,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(integer, regexp, regexp) -> integer`\n\n Returns the number of imported functions where the function's name matches\n `func_name` and the DLL name matches `dll_name`.\n\n Both `dll_name` and `func_name` are case-sensitive unless you use the \"/i\"\n modifier in the regexp, as shown in the example below. 
See [`imports_dll`]\n for details about the `import_flags` argument." + "value": "## `imports(import_flags: integer, dll_name: regexp, func_name: regexp) -> integer`\n\n Returns the number of imported functions where the function's name matches\n `func_name` and the DLL name matches `dll_name`.\n\n Both `dll_name` and `func_name` are case-sensitive unless you use the \"/i\"\n modifier in the regexp, as shown in the example below. See [`imports_dll`]\n for details about the `import_flags` argument." }, - "insertText": "imports(${1:integer},${2:regexp},${3:regexp})", + "insertText": "imports(${1:import_flags}, ${2:dll_name}, ${3:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "imports(integer, regexp, regexp)", + "label": "imports(import_flags: integer, dll_name: regexp, func_name: regexp)", "labelDetails": { "description": "func()" } @@ -2014,12 +2014,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(integer, string, string) -> bool`\n\n Returns true if the PE imports `func_name` from `dll_name`.\n\n Both `func_name` and `dll_name` are case-insensitive. See [`imports_dll`]\n for details about the `import_flags` argument." + "value": "## `imports(import_flags: integer, dll_name: string, func_name: string) -> bool`\n\n Returns true if the PE imports `func_name` from `dll_name`.\n\n Both `func_name` and `dll_name` are case-insensitive. See [`imports_dll`]\n for details about the `import_flags` argument." 
}, - "insertText": "imports(${1:integer},${2:string},${3:string})", + "insertText": "imports(${1:import_flags}, ${2:dll_name}, ${3:func_name})", "insertTextFormat": 2, "kind": 2, - "label": "imports(integer, string, string)", + "label": "imports(import_flags: integer, dll_name: string, func_name: string)", "labelDetails": { "description": "func()" } @@ -2027,12 +2027,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(integer, string, integer) -> bool`\n\n Returns true if the PE imports `ordinal` from `dll_name`.\n\n `dll_name` is case-insensitive. See [`imports_dll`] for details about\n the `import_flags` argument." + "value": "## `imports(import_flags: integer, dll_name: string, ordinal: integer) -> bool`\n\n Returns true if the PE imports `ordinal` from `dll_name`.\n\n `dll_name` is case-insensitive. See [`imports_dll`] for details about\n the `import_flags` argument." }, - "insertText": "imports(${1:integer},${2:string},${3:integer})", + "insertText": "imports(${1:import_flags}, ${2:dll_name}, ${3:ordinal})", "insertTextFormat": 2, "kind": 2, - "label": "imports(integer, string, integer)", + "label": "imports(import_flags: integer, dll_name: string, ordinal: integer)", "labelDetails": { "description": "func()" } @@ -2040,12 +2040,12 @@ { "documentation": { "kind": "markdown", - "value": "## `imports(integer, string) -> integer`\n\n Returns the number of functions imported by the PE from `dll_name`.\n\n `dll_name` is case-insensitive. `import_flags` specify the types of\n import which should be taken into account. This value can be composed\n by a bitwise OR of the following values:\n\n * `pe.IMPORT_STANDARD` : standard import only\n * `pe.IMPORT_DELAYED` : delayed imports only\n * `pe.IMPORT_ANY` : both standard and delayed imports" + "value": "## `imports(import_flags: integer, dll_name: string) -> integer`\n\n Returns the number of functions imported by the PE from `dll_name`.\n\n `dll_name` is case-insensitive. 
`import_flags` specify the types of\n import which should be taken into account. This value can be composed\n by a bitwise OR of the following values:\n\n * `pe.IMPORT_STANDARD` : standard import only\n * `pe.IMPORT_DELAYED` : delayed imports only\n * `pe.IMPORT_ANY` : both standard and delayed imports" }, - "insertText": "imports(${1:integer},${2:string})", + "insertText": "imports(${1:import_flags}, ${2:dll_name})", "insertTextFormat": 2, "kind": 2, - "label": "imports(integer, string)", + "label": "imports(import_flags: integer, dll_name: string)", "labelDetails": { "description": "func()" } @@ -2092,12 +2092,12 @@ { "documentation": { "kind": "markdown", - "value": "## `language(integer) -> bool`\n\n Returns true if the PE contains some resource with the specified language\n identifier.\n\n Language identifiers are the lowest 8-bit of locale identifiers and can\n be found here:\n https://learn.microsoft.com/en-us/windows-hardware/manufacture/desktop/available-language-packs-for-windows?view=windows-11" + "value": "## `language(lang: integer) -> bool`\n\n Returns true if the PE contains some resource with the specified language\n identifier.\n\n Language identifiers are the lowest 8-bit of locale identifiers and can\n be found here:\n https://learn.microsoft.com/en-us/windows-hardware/manufacture/desktop/available-language-packs-for-windows?view=windows-11" }, - "insertText": "language(${1:integer})", + "insertText": "language(${1:lang})", "insertTextFormat": 2, "kind": 2, - "label": "language(integer)", + "label": "language(lang: integer)", "labelDetails": { "description": "func()" } @@ -2105,12 +2105,12 @@ { "documentation": { "kind": "markdown", - "value": "## `locale(integer) -> bool`\n\n Returns true if the PE contains some resource with the specified locale\n identifier.\n\n Locale identifiers are 16-bit integers and can be found here:\n 
https://learn.microsoft.com/en-us/windows-hardware/manufacture/desktop/available-language-packs-for-windows?view=windows-11" + "value": "## `locale(loc: integer) -> bool`\n\n Returns true if the PE contains some resource with the specified locale\n identifier.\n\n Locale identifiers are 16-bit integers and can be found here:\n https://learn.microsoft.com/en-us/windows-hardware/manufacture/desktop/available-language-packs-for-windows?view=windows-11" }, - "insertText": "locale(${1:integer})", + "insertText": "locale(${1:loc})", "insertTextFormat": 2, "kind": 2, - "label": "locale(integer)", + "label": "locale(loc: integer)", "labelDetails": { "description": "func()" } @@ -2118,12 +2118,12 @@ { "documentation": { "kind": "markdown", - "value": "## `rva_to_offset(integer) -> integer`\n\n Convert a relative virtual address (RVA) to a file offset." + "value": "## `rva_to_offset(rva: integer) -> integer`\n\n Convert a relative virtual address (RVA) to a file offset." }, - "insertText": "rva_to_offset(${1:integer})", + "insertText": "rva_to_offset(${1:rva})", "insertTextFormat": 2, "kind": 2, - "label": "rva_to_offset(integer)", + "label": "rva_to_offset(rva: integer)", "labelDetails": { "description": "func()" } @@ -2131,12 +2131,12 @@ { "documentation": { "kind": "markdown", - "value": "## `section_index(string) -> integer`\n\n Returns the index in the section table of the first section with the given\n name." + "value": "## `section_index(name: string) -> integer`\n\n Returns the index in the section table of the first section with the given\n name." 
}, - "insertText": "section_index(${1:string})", + "insertText": "section_index(${1:name})", "insertTextFormat": 2, "kind": 2, - "label": "section_index(string)", + "label": "section_index(name: string)", "labelDetails": { "description": "func()" } @@ -2144,12 +2144,12 @@ { "documentation": { "kind": "markdown", - "value": "## `section_index(integer) -> integer`\n\n Returns the index in the section table of the first section that contains\n the given file offset." + "value": "## `section_index(offset: integer) -> integer`\n\n Returns the index in the section table of the first section that contains\n the given file offset." }, - "insertText": "section_index(${1:integer})", + "insertText": "section_index(${1:offset})", "insertTextFormat": 2, "kind": 2, - "label": "section_index(integer)", + "label": "section_index(offset: integer)", "labelDetails": { "description": "func()" } diff --git a/ls/src/tests/testdata/signature_help2.response.json b/ls/src/tests/testdata/signature_help2.response.json index c96665230..7ca470045 100644 --- a/ls/src/tests/testdata/signature_help2.response.json +++ b/ls/src/tests/testdata/signature_help2.response.json @@ -2,35 +2,35 @@ "activeParameter": 1, "signatures": [ { - "label": "delayed_import_rva(string, string) -> integer", + "label": "delayed_import_rva(dll_name: string, func_name: string) -> integer", "parameters": [ { "label": [ 19, - 25 + 35 ] }, { "label": [ - 27, - 33 + 37, + 54 ] } ] }, { - "label": "delayed_import_rva(string, integer) -> integer", + "label": "delayed_import_rva(dll_name: string, ordinal: integer) -> integer", "parameters": [ { "label": [ 19, - 25 + 35 ] }, { "label": [ - 27, - 34 + 37, + 53 ] } ] diff --git a/py/Cargo.toml b/py/Cargo.toml index c125e8c7e..828f1ed3f 100644 --- a/py/Cargo.toml +++ b/py/Cargo.toml @@ -60,6 +60,7 @@ pyo3 = { version = "0.28.2", features = [ "abi3-py38", "extension-module", ] } +regex = { workspace = true } serde_json = { workspace = true } strum = { workspace = true } strum_macros 
= { workspace = true } @@ -67,6 +68,7 @@ strum_macros = { workspace = true } yara-x = { workspace = true } yara-x-proto-json = { workspace = true } yara-x-fmt = { workspace = true } +yara-x-parser = { workspace = true } [build-dependencies] pyo3-build-config = "0.28.2" diff --git a/py/src/lib.rs b/py/src/lib.rs index 64699cca8..4f8b2d8e9 100644 --- a/py/src/lib.rs +++ b/py/src/lib.rs @@ -40,6 +40,7 @@ use strum_macros::{Display, EnumString}; use ::yara_x as yrx; use yara_x_fmt::Indentation; +use yara_x_parser::ast::MetaValue; fn dict_to_json(dict: Bound) -> PyResult { static JSON_DUMPS: PyOnceLock> = PyOnceLock::new(); @@ -72,6 +73,36 @@ enum SupportedModules { Dex, } +// These are copies from the checker in the CLI, but exposing them in the API +// for use here seems wrong. Maybe move them to a better place or just keep our +// own copies here? +fn is_sha256(s: &str) -> bool { + s.len() == 64 && s.chars().all(|c| c.is_ascii_hexdigit()) +} + +fn is_sha1(s: &str) -> bool { + s.len() == 40 && s.chars().all(|c| c.is_ascii_hexdigit()) +} + +fn is_md5(s: &str) -> bool { + s.len() == 32 && s.chars().all(|c| c.is_ascii_hexdigit()) +} + +/// Supported metadata types used to add linters to the compiler. +#[pyclass(from_py_object)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +enum MetaType { + STRING, + INTEGER, + FLOAT, + BOOL, + SHA256, + SHA1, + MD5, + HASH, +} + /// Formats YARA rules. #[pyclass(unsendable)] struct Formatter { @@ -420,10 +451,17 @@ impl Compiler { /// will return an InvalidRuleName warning. /// /// If the regexp does not compile a ValueError is returned. 
-    #[pyo3(signature = (regexp))]
-    fn rule_name_regexp(&mut self, regexp: &str) -> PyResult<()> {
-        let linter = yrx::linters::rule_name(regexp)
-            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+    #[pyo3(signature = (regexp, error = false))]
+    fn allowed_rule_name(
+        &mut self,
+        regexp: &str,
+        error: bool,
+    ) -> PyResult<()> {
+        // When `error` is true a rule name that does not match is reported
+        // as an error instead of a warning.
+        let linter = yrx::linters::rule_name(regexp)
+            .map_err(|err| PyValueError::new_err(err.to_string()))?
+            .error(error);
         self.inner.add_linter(linter);
         Ok(())
     }
@@ -616,6 +654,137 @@ impl Compiler {
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
         json_loads.call((warnings_json,), None)
     }
+
+    /// Adds a linter that reports tags that are not listed in `tags`.
+    ///
+    /// Violations are warnings by default, or errors if `error` is true.
+    #[pyo3(signature = (tags, error = false))]
+    fn allowed_tags(
+        &mut self,
+        tags: Vec<String>,
+        error: bool,
+    ) -> PyResult<()> {
+        self.inner.add_linter(yrx::linters::tags_allowed(tags).error(error));
+        Ok(())
+    }
+
+    /// Adds a linter that reports tags that do not match `regexp`.
+    ///
+    /// `error` has the same meaning as in `allowed_tags`. If the regexp
+    /// does not compile a ValueError is returned.
+    #[pyo3(signature = (regexp, error = false))]
+    fn allowed_tags_regex(
+        &mut self,
+        regexp: String,
+        error: bool,
+    ) -> PyResult<()> {
+        let linter = yrx::linters::tag_regex(regexp)
+            .map_err(|err| PyValueError::new_err(err.to_string()))?
+            .error(error);
+        self.inner.add_linter(linter);
+        Ok(())
+    }
+
+    /// Adds a linter that validates the metadata entry `identifier`.
+    ///
+    /// `value_type` selects the check applied to the value. `required` and
+    /// `error` are forwarded to the linter. `regexp` is only used with
+    /// `MetaType::STRING`, where the value must also match it.
+    ///
+    /// If the regexp does not compile a ValueError is returned.
+    #[pyo3(signature = (
+        identifier,
+        value_type,
+        required = false,
+        error = false,
+        regexp = None
+    ))]
+    fn allowed_metadata(
+        &mut self,
+        identifier: &str,
+        value_type: MetaType,
+        required: bool,
+        error: bool,
+        regexp: Option<String>,
+    ) -> PyResult<()> {
+        let mut linter =
+            yrx::linters::metadata(identifier).required(required).error(error);
+        match value_type {
+            MetaType::STRING => {
+                // Compile the pattern once so that an invalid regexp is
+                // reported as a ValueError here instead of panicking inside
+                // the validator. The bytes engine serves both value kinds.
+                let compiled = regexp
+                    .as_deref()
+                    .map(regex::bytes::Regex::new)
+                    .transpose()
+                    .map_err(|err| PyValueError::new_err(err.to_string()))?;
+                let message = match &regexp {
+                    Some(regexp) => format!(
+                        "`{identifier}` must be a string that matches `/{regexp}/`"
+                    ),
+                    None => format!("`{identifier}` must be a string"),
+                };
+                linter = linter.validator(
+                    move |meta| match (&meta.value, &compiled) {
+                        (MetaValue::String((s, _)), Some(re)) => {
+                            re.is_match(s.as_bytes())
+                        }
+                        (MetaValue::Bytes((s, _)), Some(re)) => re.is_match(s),
+                        (MetaValue::String(_) | MetaValue::Bytes(_), None) => {
+                            true
+                        }
+                        _ => false,
+                    },
+                    message,
+                );
+            }
+            MetaType::INTEGER => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::Integer(_)),
+                    format!("`{identifier}` must be an integer"),
+                );
+            }
+            MetaType::FLOAT => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::Float(_)),
+                    format!("`{identifier}` must be a float"),
+                );
+            }
+            MetaType::BOOL => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::Bool(_)),
+                    format!("`{identifier}` must be a bool"),
+                );
+            }
+            MetaType::SHA256 => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::String((s,_)) if is_sha256(s)),
+                    format!("`{identifier}` must be a SHA-256"),
+                );
+            }
+            MetaType::SHA1 => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::String((s,_)) if is_sha1(s)),
+                    format!("`{identifier}` must be a SHA-1"),
+                );
+            }
+            MetaType::MD5 => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::String((s,_)) if is_md5(s)),
+                    format!("`{identifier}` must be a MD5"),
+                );
+            }
+            MetaType::HASH => {
+                linter = linter.validator(
+                    |meta| matches!(meta.value, MetaValue::String((s,_)) if is_md5(s) || is_sha1(s) || is_sha256(s)),
+                    format!("`{identifier}` must be a MD5, SHA-1 or SHA-256"),
+                );
+            }
+        }
+        self.inner.add_linter(linter);
+        Ok(())
+    }
 }
 
 /// Optional information for the scan operation.
@@ -1306,6 +1475,7 @@ fn yara_x(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<MetaType>()?;
     m.gil_used(false)?;
     Ok(())
 }
diff --git a/py/tests/test_api.py b/py/tests/test_api.py
index 587162b4f..8dc10b1e0 100644
--- a/py/tests/test_api.py
+++ b/py/tests/test_api.py
@@ -32,7 +32,7 @@ def test_error_on_slow_pattern():
 def test_invalid_rule_name_regexp():
   compiler = yara_x.Compiler()
   with pytest.raises(ValueError):
-    compiler.rule_name_regexp("(AXS|ERS")
+    compiler.allowed_rule_name("(AXS|ERS")
 
 
 def test_int_globals():
@@ -397,3 +397,66 @@ def test_rules_imports():
   }
   ''')
   assert rules.imports() == ["pe", "elf"]
+
+def test_invalid_metadata_regexp():
+  compiler = yara_x.Compiler()
+  with pytest.raises(ValueError):
+    compiler.allowed_metadata('a', yara_x.MetaType.STRING, regexp='(AXS|ERS')
+
+def test_check_allowed_tags_error():
+  rule = '''
+  rule test: a b c d { condition: 1 + 1 == 2}
+  rule test2: d { condition: 1 + 1 == 2}'''
+  compiler = yara_x.Compiler()
+  compiler.allowed_tags(['a', 'b'], error = True)
+  with pytest.raises(yara_x.CompileError,
+                     match="tag `c` not in allowed list"):
+    compiler.add_source(rule)
+  # The current behavior is to stop checking tags on the rule after the first
+  # tag fails, but subsequent rules are also checked.
+  errors = compiler.errors()
+  assert len(errors) == 2
+  assert 'tag `c` not in allowed list' in errors[0]['text']
+  assert 'tag `d` not in allowed list' in errors[1]['text']
+
+def test_check_allowed_tags_warning():
+  compiler = yara_x.Compiler()
+  compiler.allowed_tags(['a', 'b'])
+  compiler.add_source('rule test: a b c d { condition: 1 + 1 == 2}')
+  warnings = compiler.warnings()
+  assert len(warnings) == 2
+  assert 'tag `c` not in allowed list' in warnings[0]['text']
+  assert 'tag `d` not in allowed list' in warnings[1]['text']
+
+def test_check_metadata():
+  compiler = yara_x.Compiler()
+  compiler.allowed_metadata('a', yara_x.MetaType.STRING)
+  compiler.allowed_metadata('b', yara_x.MetaType.STRING, regexp='^bar')
+  compiler.add_source('rule test { meta: a = 1 b = "foo" condition: 1 + 1 == 2}')
+  warnings = compiler.warnings()
+  assert len(warnings) == 2
+  assert '`a` must be a string' in warnings[0]['text']
+  assert '`b` must be a string that matches `/^bar/`' in warnings[1]['text']
+
+def test_check_rule_name_regexp():
+  rule = '''
+  rule test { condition: 1 + 1 == 2}
+  rule test2 { condition: 1 + 1 == 2}'''
+  compiler = yara_x.Compiler()
+  compiler.allowed_rule_name('^foo')
+  compiler.add_source(rule)
+  warnings = compiler.warnings()
+  assert len(warnings) == 2
+  assert 'this rule name does not match regex `^foo`' in warnings[0]['text']
+
+def test_check_rule_name_regexp_error():
+  rule = '''
+  rule test { condition: 1 + 1 == 2}
+  rule test2 { condition: 1 + 1 == 2}'''
+  compiler = yara_x.Compiler()
+  compiler.allowed_rule_name('^foo', error = True)
+  with pytest.raises(yara_x.CompileError,
+                     match=r"this rule name does not match regex `\^foo`"):
+    compiler.add_source(rule)
+  assert len(compiler.errors()) == 2
diff --git a/py/yara_x.pyi b/py/yara_x.pyi
index 5067688c4..0742a1194 100644
--- a/py/yara_x.pyi
+++ b/py/yara_x.pyi
@@ -1,6 +1,7 @@
 import collections
 
-from typing import Any, Dict, BinaryIO, TextIO, Optional, Tuple, final
+from typing import Any, Dict, BinaryIO, TextIO, Optional, Tuple, final, List
+from enum import Enum
 
 class CompileError(Exception):
     r"""
@@ -156,7 +157,7 @@ class Compiler:
         """
         ...
 
-    def rule_name_regexp(self, regexp: str) -> None:
+    def allowed_rule_name(self, regexp: str, error: bool = False) -> None:
         r"""
         Tell the compiler that any rule must match this regular expression or it
         will result in a compiler warning.
@@ -168,6 +169,24 @@ class Compiler:
         """
         ...
 
+    def allowed_tags(self, tags: List[str], error: bool = False) -> None:
+        r"""List the allowed tags for rules."""
+        ...
+
+    def allowed_tags_regex(self, regexp: str, error: bool = False) -> None:
+        r"""A regular expression that must match all tags on rules."""
+        ...
+
+    def allowed_metadata(
+        self,
+        identifier: str,
+        value_type: MetaType,
+        required: bool = False,
+        error: bool = False,
+        regexp: Optional[str] = None) -> None:
+        r"""Define expected type and value for metadata on rules."""
+        ...
+
 @final
 class ScanOptions:
     r"""
@@ -480,3 +499,34 @@ class Module:
     def invoke(self, data: str) -> Any:
         r"""Parse the data and collect module metadata."""
         ...
+
+@final
+class MetaType(Enum):
+    r"""Supported metadata types used to add linters to the compiler."""
+    STRING: int
+    INTEGER: int
+    FLOAT: int
+    BOOL: int
+    SHA256: int
+    SHA1: int
+    MD5: int
+    HASH: int
+
+@final
+class CheckResult:
+    r"""Result from the [`Compiler::check`] method after checking source code."""
+    def warning(self) -> bool:
+        r"""True if the result is a warning, false if it is an error."""
+        ...
+
+    def code(self) -> str:
+        r"""The string representation of the result code."""
+        ...
+
+    def title(self) -> str:
+        r"""The title of the result code."""
+        ...
+
+    def message(self) -> str:
+        r"""A multi-line message containing code, title and full compiler details."""
+        ...