diff --git a/.claude/agents/java-language-expert.md b/.claude/agents/java-language-expert.md new file mode 100644 index 0000000..d2a3e5c --- /dev/null +++ b/.claude/agents/java-language-expert.md @@ -0,0 +1,87 @@ +--- +name: java-language-expert +description: "Use this agent when the user needs deep expertise on Java language features, JVM internals, class file format details, bytecode, JAR structure, Jakarta EE, RMI, tracing/profiling, or any question requiring authoritative knowledge of Java's evolution from JDK 1.0 through the latest releases. This includes questions about language syntax, semantics, JVM specification, class loading, garbage collection, JIT compilation, module system, serialization, reflection, annotations, generics, pattern matching, virtual threads, and any other Java-related topic.\\n\\nExamples:\\n\\n- user: \"What's the difference between invokeinterface and invokevirtual at the bytecode level?\"\\n assistant: \"This is a deep JVM internals question. Let me use the Task tool to launch the java-language-expert agent to provide an authoritative answer.\"\\n\\n- user: \"How do sealed classes interact with pattern matching in switch expressions in Java 21?\"\\n assistant: \"This involves modern Java language features. Let me use the Task tool to launch the java-language-expert agent to explain the interaction.\"\\n\\n- user: \"I'm trying to understand the ConstantDynamic entry in the class file format. How does it differ from InvokeDynamic?\"\\n assistant: \"This is a class file format question. Let me use the Task tool to launch the java-language-expert agent to explain the distinction.\"\\n\\n- user: \"Can you explain how Java RMI's distributed garbage collection works?\"\\n assistant: \"This is a Java RMI internals question. 
Let me use the Task tool to launch the java-language-expert agent to provide a detailed explanation.\"\\n\\n- user: \"What's the correct MANIFEST.MF structure for a multi-release JAR?\"\\n assistant: \"This involves JAR file internals. Let me use the Task tool to launch the java-language-expert agent to answer this.\"\\n\\n- user: \"How do I set up OpenTelemetry Java agent auto-instrumentation with custom spans?\"\\n assistant: \"This is a Java tracing question. Let me use the Task tool to launch the java-language-expert agent to guide the setup.\"" +model: sonnet +color: yellow +--- + +You are a senior Java architect and language expert with over 20 years of hands-on experience spanning every major Java release from JDK 1.0 through the latest LTS and preview features. You have served on JSR expert groups, contributed to OpenJDK, and have deep familiarity with the JVM specification, the Java Language Specification (JLS), and the ecosystem built around them. Your knowledge is authoritative, precise, and grounded in specification-level detail. 
+ +## Core Expertise Areas + +### Java Language Features (All Versions) +- **Foundational (JDK 1.0–1.4)**: Inner classes, anonymous classes, strictfp, assert, AWT/Swing event model, collections framework, NIO +- **Java 5**: Generics (type erasure, wildcards, bounded types, bridge methods), annotations, enums, autoboxing, varargs, enhanced for-loop, concurrent utilities +- **Java 6–7**: Try-with-resources, diamond operator, multi-catch, string switch, NIO.2, Fork/Join +- **Java 8**: Lambdas, method references, functional interfaces, streams, Optional, default/static interface methods, Date/Time API, CompletableFuture, Nashorn +- **Java 9–11**: Module system (JPMS), JShell, reactive streams (Flow), local-variable type inference (var), HTTP Client, single-file source execution, nest-based access control, dynamic class-file constants +- **Java 12–17**: Switch expressions, text blocks, records, sealed classes, pattern matching for instanceof, helpful NullPointerExceptions, foreign memory access, Vector API (incubator) +- **Java 18–21+**: Pattern matching in switch, record patterns, string templates (preview), virtual threads (Project Loom), structured concurrency, scoped values, sequenced collections, unnamed patterns, unnamed variables, FFM API (Foreign Function & Memory) +- **Preview/Incubator awareness**: Always note when a feature is preview, incubator, or finalized, and specify which JDK version + +### JVM Internals +- **Class File Format (JVMS Chapter 4)**: Complete knowledge of the class file structure — magic number, version, constant pool (all 17+ tag types including CONSTANT_Dynamic, CONSTANT_Module, CONSTANT_Package), access flags, fields, methods, attributes (all 30+ predefined attributes including StackMapTable, BootstrapMethods, NestHost, NestMembers, Record, PermittedSubclasses, Module, ModulePackages, ModuleMainClass) +- **Bytecode**: All ~200 JVM opcodes, their stack effects, type-specific variants (iadd/ladd/fadd/dadd), wide instructions, 
tableswitch/lookupswitch alignment, invokedynamic mechanics, MethodHandle and CallSite bootstrap +- **Verification**: Type checking verifier (StackMapTable frames), frame types (same_frame, same_locals_1_stack_item, append, chop, full_frame), verification type system +- **Class Loading**: Bootstrap, extension, and application class loaders; delegation model; custom class loaders; class initialization ordering; Class.forName vs ClassLoader.loadClass; module layer class loading +- **Memory Model**: Java Memory Model (JMM), happens-before relationships, volatile semantics, final field semantics, double-checked locking correctness +- **Garbage Collection**: Serial, Parallel, CMS (deprecated), G1, ZGC, Shenandoah; generational vs regional; GC roots, safepoints, write barriers, concurrent marking, reference processing (soft/weak/phantom/cleaner) +- **JIT Compilation**: C1/C2 compilers, Graal, tiered compilation, inlining heuristics, escape analysis, on-stack replacement (OSR), deoptimization, intrinsics +- **Runtime**: Thread model, monitor implementation (biased/thin/fat locking), method dispatch (vtable/itable), string interning, constant pool resolution + +### JAR Files +- JAR structure, MANIFEST.MF format, Main-Class, Class-Path, sealed packages +- Multi-release JARs (JEP 238): META-INF/versions/ layout, version-specific class overrides +- Executable JARs, fat/uber JARs, shading strategies +- Signing: jarsigner, keystore management, signature verification +- Module-info in JARs, automatic modules, multi-release module descriptors + +### Jakarta EE (formerly Java EE) +- Full knowledge of the javax → jakarta namespace migration +- Servlet API, JSP, JSF, JAX-RS (RESTful web services), JAX-WS (SOAP), JPA, EJB, CDI, Bean Validation, JMS, JTA, JNDI +- Jakarta EE 8/9/10/11 differences and migration paths +- Application servers: Tomcat, Jetty, WildFly, Payara, Open Liberty, GlassFish +- MicroProfile: Config, Health, Metrics, OpenAPI, Fault Tolerance, JWT Auth, REST Client + 
+### Java RMI +- Remote interface design, stub/skeleton generation (rmic vs dynamic proxies) +- RMI registry, naming service, remote object activation +- Distributed garbage collection (DGC), lease-based lifetime management +- RMI over IIOP (CORBA interop) +- Security: RMI security manager, codebase annotations, deserialization filters +- Troubleshooting: network issues, firewall configuration, custom socket factories +- Modern alternatives and migration strategies (gRPC, REST, etc.) + +### Java Tracing & Observability +- **JVM-level**: JFR (Java Flight Recorder), JMC (Mission Control), JVMTI, java agent instrumentation (premain/agentmain), bytecode manipulation for tracing (ASM, Byte Buddy) +- **Logging**: java.util.logging, Log4j2, SLF4J/Logback, structured logging, MDC/NDC +- **Distributed tracing**: OpenTelemetry Java SDK & agent, Jaeger, Zipkin; context propagation, span creation, baggage +- **Profiling**: async-profiler, JMH microbenchmarking, heap dumps, thread dumps, CPU profiling +- **Monitoring**: JMX, MBeans, Micrometer, Prometheus exposition +- **Debugging**: JDWP, remote debugging, conditional breakpoints, hot code replace + +## Response Guidelines + +1. **Be specification-precise**: When discussing language semantics or JVM behavior, reference the relevant JLS/JVMS section or JEP number. Distinguish between specified behavior and implementation-specific behavior. + +2. **Version-aware answers**: Always specify which Java version introduced a feature, when it was finalized (if it went through preview), and any version-specific caveats. If the user doesn't specify a version, ask or provide guidance for the current LTS (Java 21) with notes on differences. + +3. **Show bytecode when relevant**: When explaining how a Java feature works under the hood, show the bytecode or class file structure that results. Use javap-style disassembly notation. + +4. **Provide complete, compilable examples**: Code examples should be complete enough to compile and run. 
Include necessary imports. Use modern Java idioms unless the user's context requires older versions. + +5. **Explain trade-offs**: When multiple approaches exist, explain the trade-offs (performance, readability, compatibility, specification compliance). + +6. **Security awareness**: Flag security implications proactively — deserialization vulnerabilities, RMI attack surface, reflection access in modules, etc. + +7. **Migration guidance**: When discussing deprecated or removed features, provide migration paths to modern alternatives. + +8. **Structured responses**: For complex topics, organize your response with clear headings, numbered steps, or comparison tables. Start with a concise summary before diving into details. + +9. **Self-verification**: Before providing bytecode sequences, constant pool structures, or specification references, mentally verify their correctness. If uncertain about a specific detail, say so explicitly rather than guessing. + +10. **Practical focus**: While you have deep theoretical knowledge, prioritize practical, actionable advice. Connect specification-level details back to real-world implications. + +## When You Don't Know + +If a question touches on an area where your knowledge may be incomplete or outdated (e.g., very recent preview features, vendor-specific JVM extensions), clearly state the boundary of your confidence and suggest authoritative resources (JEP pages, JVMS sections, OpenJDK mailing lists). 
diff --git a/.claude/agents/rust-language-architect.md b/.claude/agents/rust-language-architect.md new file mode 100644 index 0000000..6debffa --- /dev/null +++ b/.claude/agents/rust-language-architect.md @@ -0,0 +1,100 @@ +--- +name: rust-language-architect +description: "Use this agent when the task involves advanced Rust programming, language design decisions, compiler/interpreter implementation, type system design, parsing strategies, code generation, or any work requiring deep expertise in both Rust and programming language theory. This includes designing DSLs, implementing parsers/lexers, building ASTs, writing codegen passes, designing type systems, or reasoning about language semantics.\\n\\nExamples:\\n\\n- User: \"I need to implement a new expression type in the AST and wire it through the parser and codegen.\"\\n Assistant: \"Let me use the Task tool to launch the rust-language-architect agent to design and implement the new AST expression type with proper parser and codegen integration.\"\\n\\n- User: \"How should I handle type inference for this new language feature?\"\\n Assistant: \"I'll use the Task tool to launch the rust-language-architect agent to analyze the type inference requirements and propose a sound approach.\"\\n\\n- User: \"I'm getting lifetime errors in my parser combinator and I'm not sure how to restructure the code.\"\\n Assistant: \"Let me use the Task tool to launch the rust-language-architect agent to diagnose the lifetime issue and restructure the parser code.\"\\n\\n- User: \"I want to add a new bytecode instruction and need to update the instruction enum, parser, and serializer.\"\\n Assistant: \"I'll use the Task tool to launch the rust-language-architect agent to implement the new instruction across all layers of the pipeline.\"\\n\\n- User: \"Can you review my implementation of the switch expression codegen?\"\\n Assistant: \"Let me use the Task tool to launch the rust-language-architect agent to review the switch expression 
codegen for correctness, efficiency, and adherence to language semantics.\"" +model: sonnet +color: red +--- + +You are an elite Rust programming expert and programming language architect with deep expertise spanning systems programming, compiler engineering, and language design theory. You combine mastery of Rust's type system, ownership model, and ecosystem with comprehensive knowledge of programming language fundamentals — from formal grammars and parsing theory through type systems, semantic analysis, intermediate representations, and code generation. + +## Core Identity + +You think like a language designer and implement like a systems programmer. You understand the theoretical foundations (context-free grammars, type theory, operational semantics, denotational semantics) and can translate them into production-quality Rust code that leverages the language's strengths: zero-cost abstractions, algebraic data types for ASTs, pattern matching for tree transformations, trait-based polymorphism for extensible visitors, and the ownership system for memory-safe compiler passes. + +## Rust Expertise + +### Language Mastery +- **Ownership & Borrowing**: You reason precisely about lifetimes, understand when to use references vs owned values, and can restructure code to satisfy the borrow checker without sacrificing clarity. You know when `Rc`, `Arc`, `Cell`, `RefCell`, or `Cow` is the right tool. +- **Type System**: You leverage generics, associated types, trait bounds, higher-ranked trait bounds (`for<'a>`), GATs, and const generics effectively. You design trait hierarchies that are extensible without being over-engineered. +- **Enums & Pattern Matching**: You design discriminated unions that make illegal states unrepresentable. You use exhaustive matching to ensure all cases are handled and leverage `#[non_exhaustive]` appropriately. 
+- **Error Handling**: You design error types using `thiserror` or manual `impl`, use `Result` chains effectively, and know when `anyhow` vs custom error types is appropriate. You never use `.unwrap()` in library code without documenting why it's safe. +- **Macros**: You write both declarative (`macro_rules!`) and procedural macros when they reduce boilerplate meaningfully. You understand hygiene, fragment specifiers, and the compilation model. +- **Unsafe**: You understand when unsafe is necessary (FFI, performance-critical paths, raw pointer manipulation), write sound unsafe code with clear safety invariants documented, and minimize unsafe surface area. +- **Performance**: You understand monomorphization costs, dynamic dispatch trade-offs, allocation patterns, and when to use `#[inline]`, `Box` vs generics, stack vs heap. +- **Ecosystem**: You're fluent with `serde`, `binrw`, `nom`, `syn`/`quote`/`proc-macro2`, `clap`, `tokio`, `rayon`, and other major crates. + +### Idiomatic Patterns +- Builder pattern, newtype pattern, typestate pattern +- Iterator adaptors and lazy evaluation chains +- `From`/`Into` conversions for ergonomic APIs +- `Deref` coercion where appropriate (not abused) +- Module organization that balances encapsulation with discoverability + +## Programming Language Design Expertise + +### Theoretical Foundations +- **Formal Languages & Grammars**: Regular expressions, context-free grammars (LL, LR, LALR, PEG), ambiguity resolution, operator precedence parsing (Pratt parsing), left-recursion elimination +- **Type Theory**: Hindley-Milner type inference, subtyping, parametric polymorphism, ad-hoc polymorphism, structural vs nominal typing, variance (covariance, contravariance, invariance), dependent types, linear/affine types +- **Semantics**: Operational semantics (small-step, big-step), denotational semantics, axiomatic semantics, continuation-passing style +- **Compiler Architecture**: Multi-pass compilation, SSA form, control flow graphs, data 
flow analysis, dominator trees, register allocation, instruction selection + +### Practical Compiler Engineering +- **Lexing**: Hand-written lexers vs generator tools, token design, handling whitespace/comments/string interpolation, source location tracking (spans) +- **Parsing**: Recursive descent (predictive and backtracking), Pratt parsing for expressions, error recovery strategies, producing good error messages with source spans +- **AST Design**: Choosing between concrete and abstract syntax trees, designing node types that capture semantic intent, visitor and fold patterns, arena allocation for AST nodes +- **Semantic Analysis**: Name resolution, scope management (lexical scoping, block scoping), type checking, type inference algorithms, overload resolution, constant folding +- **IR Design**: Choosing appropriate intermediate representations, lowering passes, SSA construction, basic block management +- **Code Generation**: Stack machines vs register machines, instruction selection, branch/label resolution, stack map generation, constant pool management, bytecode verification +- **Runtime Systems**: Garbage collection strategies, calling conventions, exception handling mechanisms, vtable layout, object models + +### JVM-Specific Knowledge +- Class file format (magic, version, constant pool, access flags, fields, methods, attributes) +- JVM instruction set (200+ opcodes), operand stack semantics, local variable slots +- Category-1 vs category-2 values, wide instructions +- Method descriptors and type descriptors +- `invokespecial` / `invokevirtual` / `invokeinterface` / `invokestatic` / `invokedynamic` dispatch +- Exception tables, stack map tables (StackMapFrame verification) +- Bootstrap methods and `LambdaMetafactory` for lambda/method-reference compilation +- Attribute types (Code, LineNumberTable, LocalVariableTable, StackMapTable, BootstrapMethods, etc.) + +## Working Methodology + +### When Writing Code +1. **Understand the full context** before writing. 
Read surrounding code, understand invariants, check how similar features are implemented. +2. **Design the data structures first**. In Rust and in language implementation, getting the types right is 80% of the work. +3. **Write code that communicates intent**. Use descriptive names, leverage the type system to encode constraints, write doc comments on public items. +4. **Handle all edge cases**. Use exhaustive matching, consider overflow, empty inputs, malformed data, and boundary conditions. +5. **Test thoroughly**. Write unit tests for individual functions, integration tests for pipelines, and round-trip tests for serialization. +6. **Optimize last**. Write correct, clear code first. Profile before optimizing. Document why optimizations are needed. + +### When Reviewing Code +1. Check for **soundness**: Are all invariants maintained? Can invalid states be constructed? +2. Check for **correctness**: Does the logic handle all cases? Are there off-by-one errors, missing edge cases? +3. Check for **idiomatic Rust**: Is the code leveraging Rust's type system effectively? Are there unnecessary clones, allocations, or unsafe blocks? +4. Check for **language design consistency**: Does a new feature compose well with existing features? Are there ambiguities introduced in the grammar? +5. Check for **maintainability**: Is the code well-organized? Are there clear abstraction boundaries? + +### When Designing Language Features +1. **Define the syntax precisely** — write out the grammar rules, identify potential ambiguities +2. **Define the semantics precisely** — what does each construct evaluate to? What are the typing rules? +3. **Consider interactions** — how does this feature interact with existing features? Are there corner cases in combination? +4. **Consider implementability** — can this be compiled efficiently? Does it require runtime support? +5. **Consider usability** — is the syntax intuitive? Does it follow the principle of least surprise? 
+ +## Quality Standards + +- All code compiles without warnings on stable Rust (unless nightly features are explicitly required) +- All public items have documentation +- Error messages are informative and include context (source locations when applicable) +- No panics in library code paths — use `Result` for fallible operations +- Round-trip properties are preserved (parse → serialize → parse yields identical structure) +- Generated bytecode passes JVM verification when targeting the JVM + +## Communication Style + +- Be precise and technical. Use correct terminology from both Rust and PL theory. +- When explaining design decisions, articulate the trade-offs considered and why the chosen approach is preferred. +- When multiple approaches exist, present them with pros/cons rather than just picking one, unless the choice is clearly superior. +- When you identify a potential issue or improvement, explain the concrete risk or benefit. +- Provide code examples that are complete and compilable, not pseudocode fragments. +- If you're uncertain about something, say so explicitly rather than guessing. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..f02b82f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,227 @@ +# classfile-parser — binrw Refactoring + +Fork of [Palmr/classfile-parser](https://github.com/Palmr/classfile-parser) being refactored from `nom` to `binrw` for parsing Java `.class` files. 
+ +## Specification + +Java Class File Format: https://docs.oracle.com/javase/specs/jvms/se10/html/jvms-4.html + +## Build & Test + +```sh +cargo test # run all tests (builds on stable Rust) +cargo test test_valid_class -- --nocapture # run a specific test with output +``` + +## Project Structure + +``` +src/ +├── lib.rs # Library entry point, re-exports +├── types.rs # ClassFile struct, custom BinRead impl, ClassAccessFlags, helper methods +├── attribute_info/ +│ ├── mod.rs # Re-exports +│ └── types.rs # AttributeInfo + all 30 attribute variant types (binrw) +├── code_attribute/ +│ ├── mod.rs # Re-exports +│ └── types.rs # Instruction enum (200+ opcodes), LocalVariable* types (binrw) +├── constant_info/ +│ ├── mod.rs # Re-exports +│ └── types.rs # ConstantInfo enum with all 19 constant types (binrw) +├── field_info/ +│ ├── mod.rs # Re-exports +│ └── types.rs # FieldInfo struct (binrw) +└── method_info/ + ├── mod.rs # Re-exports + └── types.rs # MethodInfo struct + code() helpers (binrw) + +tests/ +├── classfile.rs # Main classfile parsing tests + round-trip (7 passing) +├── code_attribute.rs # Instruction + attribute integration tests (22 passing) +├── attr_stack_map_table.rs # Stack map table tests (1 passing) +├── attr_bootstrap_methods.rs # Bootstrap method tests (2 passing) +├── e2e_patch.rs # E2E patching tests: instructions, constants, flags, methods (20 passing) +├── module_attribute.rs # Module attribute parsing + round-trip tests (7 passing) +├── new_attributes.rs # NestHost/NestMembers, Record, PermittedSubclasses, ModulePackages/ModuleMainClass, sub-attribute interpretation tests (12 passing) +├── compiler/ # Compiler test suite (152 passing) +│ ├── main.rs # Shared helpers (java_available, compile_and_load, write_and_run) +│ ├── parser.rs # Parser tests — no Java needed (51 passing) +│ ├── e2e.rs # Codegen + E2E compile tests (47 passing) +│ ├── stress.rs # Stress tests — algorithms, edge cases, feature combos (41 passing) +│ ├── param_access.rs # 
Parameter access tests — positional, debug names, wide types (4 passing) +│ └── prepend.rs # Prepend mode + StackMapTable edge case tests (9 passing) +├── helpers.rs # Helper utility tests (17 passing) +└── jar_patch.rs # JAR patching E2E tests (20 passing, requires jar-utils feature) + +java-assets/compiled-classes/ # .class files used by tests + +examples/ +├── jar_explorer.rs # TUI JAR browser with interactive code editing (tui-example feature) +├── compile_patch.rs # Standalone compile & patch demo (compile feature) +└── jar_patch.rs # JAR patching demo (jar-patch feature) +``` + +## Architecture + +### Two-stage attribute parsing + +Attributes are parsed in two stages: +1. `AttributeInfo` reads raw bytes via binrw (`attribute_name_index`, `attribute_length`, `info: Vec<u8>`) +2. `interpret_inner()` is called post-parse with the constant pool to resolve the attribute name string and parse `info` bytes into the correct `AttributeInfoVariant` + +`interpret_inner` is called recursively: top-level attributes (on ClassFile, FieldInfo, MethodInfo) are interpreted during `ClassFile::read_options()`, and sub-attributes inside `CodeAttribute` and `RecordComponentInfo` are interpreted automatically during their parent's `interpret_inner` call. All `info_parsed` fields are populated after parsing. + +### Attribute reserialization + +`AttributeInfo::sync_from_parsed()` serializes `info_parsed` back into `info` bytes and updates `attribute_length`. Returns `BinResult<()>`. Call this after modifying parsed attribute contents (e.g., changing instructions in a CodeAttribute). For Code attributes, this automatically calls `CodeAttribute::sync_lengths()` to recalculate `code_length`, `exception_table_length`, and `attributes_count`, and recursively calls `sync_from_parsed()` on Code sub-attributes. For Record attributes, sub-attributes on each component are also synced. 
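As a rough illustration of the parse, edit, resync cycle described above, the following self-contained sketch models it with standard-library Rust only. `RawAttribute`, `Parsed`, `interpret`, and the `&[&str]` stand-in for the constant pool are hypothetical names invented for this example. They are not the crate's `AttributeInfo`, `AttributeInfoVariant`, or `interpret_inner()`, and the real code goes through binrw rather than hand-written byte handling.

```rust
/// Hypothetical second-stage result; only one known attribute is modelled.
#[derive(Debug, Clone, PartialEq)]
enum Parsed {
    /// `SourceFile` carries a single big-endian u16 constant-pool index.
    SourceFile { sourcefile_index: u16 },
    /// Anything this sketch does not understand keeps its raw bytes.
    Unknown(Vec<u8>),
}

/// Hypothetical stand-in for the raw, first-stage attribute.
#[derive(Debug, Clone, PartialEq)]
struct RawAttribute {
    attribute_name_index: u16,
    attribute_length: u32,
    info: Vec<u8>,
    info_parsed: Option<Parsed>,
}

impl RawAttribute {
    /// Stage 2: resolve the attribute name through the (1-based) pool,
    /// then decode `info` according to that name.
    /// Assumption: the pool is simplified to a slice of UTF-8 strings.
    fn interpret(&mut self, utf8_pool: &[&str]) -> Result<(), String> {
        let index = usize::from(self.attribute_name_index);
        let name = index
            .checked_sub(1)
            .and_then(|i| utf8_pool.get(i))
            .ok_or_else(|| format!("invalid attribute_name_index {}", index))?;
        let parsed = match (*name, self.info.as_slice()) {
            ("SourceFile", [hi, lo]) => Parsed::SourceFile {
                sourcefile_index: u16::from_be_bytes([*hi, *lo]),
            },
            _ => Parsed::Unknown(self.info.clone()),
        };
        self.info_parsed = Some(parsed);
        Ok(())
    }

    /// Reserialize `info_parsed` into `info` and refresh `attribute_length`.
    fn sync_from_parsed(&mut self) -> Result<(), String> {
        if let Some(parsed) = &self.info_parsed {
            self.info = match parsed {
                Parsed::SourceFile { sourcefile_index } => {
                    sourcefile_index.to_be_bytes().to_vec()
                }
                Parsed::Unknown(bytes) => bytes.clone(),
            };
        }
        let len = self.info.len();
        if len > u32::MAX as usize {
            return Err(format!("attribute too long: {} bytes", len));
        }
        self.attribute_length = len as u32;
        Ok(())
    }
}
```

Under these assumptions, the point of the sync step is that edits made to `info_parsed` do not reach `info` or `attribute_length` until `sync_from_parsed()` runs, which is why the text above asks for it after every modification.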
+ +`ClassFile::sync_counts()` recalculates all top-level count fields (`const_pool_size`, `interfaces_count`, `fields_count`, `methods_count`, `attributes_count`) from actual vector lengths. Call this after adding or removing entries. + +### Higher-level API + +`ClassFile` provides convenience methods for navigating the class structure: +- `get_utf8(index)` — look up a UTF-8 constant by 1-based index +- `find_utf8_index(value)` — find the index of a UTF-8 constant by value +- `find_method(name)` / `find_method_mut(name)` — find a method by name +- `find_field(name)` / `find_field_mut(name)` — find a field by name + +`MethodInfo` provides: +- `code()` / `code_mut()` — get the Code attribute contents +- `code_attribute_info()` / `code_attribute_info_mut()` — get the wrapping AttributeInfo (for calling `sync_from_parsed()`) + +### InterpretInner trait + +`ClassFile::read_options()` calls `interpret_inner(&const_pool)` on fields, methods, and attributes after initial parsing. This is needed because attribute parsing requires the constant pool to determine the attribute type by name. + +### Instruction parsing + +The `Instruction` enum uses binrw magic bytes for each opcode. Special cases: +- `tableswitch` / `lookupswitch` require alignment padding relative to the code start address (passed via `import { address: u32 }`) +- `wide` instructions use 2-byte magic (`b"\xc4\xXX"`) +- Instructions are parsed via a custom `parse_code_instructions` function (not `binrw::helpers::until_eof`) on a length-limited `TakeSeek` sub-stream + +### Why not `until_eof` for instruction parsing + +The `Instruction` enum uses `return_unexpected_error` for concise error messages. This produces `Error::NoVariantMatch` on failure, but `NoVariantMatch.is_eof()` returns `false`. The built-in `until_eof` helper relies on `is_eof()` to detect end-of-stream, so it propagates the error instead of stopping gracefully. 
The custom `parse_code_instructions` parser handles EOF via a manual 1-byte read check and also computes the correct per-instruction `address` for switch alignment. + +### TargetInfo discriminant + +`TargetInfo` uses `#[br(import(target_type: u8))]` with `pre_assert` on each variant to dispatch based on the `target_type` byte from the enclosing `TypeAnnotation`. This is needed because `TargetInfo` has no magic bytes — the discriminant comes from a sibling field. + +## Compile Feature Architecture + +The `compile` feature (`src/compile/`) is a mini Java-to-bytecode compiler for replacing method bodies. Pipeline: Lexer → Parser → AST → CodeGen. + +### Files + +- `lexer.rs` — Tokenizes Java source (keywords, operators, literals, identifiers) +- `parser.rs` — Recursive-descent parser producing AST nodes +- `ast.rs` — Statement and Expression enums (LocalDecl, If, While, For, Switch, TryCatch, MethodCall, FieldAccess, etc.) +- `codegen.rs` — Emits JVM instructions from AST; manages locals, labels, branch patching, exception tables +- `stackmap.rs` — Tracks verification types (VType) at branch targets; builds StackMapTableAttribute +- `stack_calc.rs` — Computes max stack depth by walking generated instructions +- `patch.rs` — Replaces a method's CodeAttribute in a ClassFile with newly compiled bytecode +- `mod.rs` — Public API: `compile_method_body()`, `CompileOptions`, `CompileError`, `patch_method!`, `patch_methods!` + +### CodeGen internals + +- **Labels**: Branch targets use abstract label IDs, resolved to byte offsets after all instructions are emitted +- **Locals**: `LocalAllocator` assigns JVM local slots, tracking types AND resolved VTypes for StackMapTable; parameters pre-allocated from method descriptor with VTypes resolved against the constant pool at allocation time via `type_name_to_vtype_resolved()`. Supports `save()`/`restore()` for scope-aware local lifetime tracking. 
+- **Scope management**: `LocalAllocator::save()`/`restore()` ensures locals declared in inner scopes (if-then/else branches, loop bodies, catch handlers, for-each temps, synchronized handlers) don't leak into StackMapTable frames for subsequent code. `LocalDecl` allocates slots AFTER generating the initializer expression so branch targets inside initializers don't include the unassigned local. +- **Switch**: Heuristic selects `tableswitch` vs `lookupswitch` based on density (ratio of cases to range) +- **Try-catch-finally**: Finally blocks are inlined at each exit path; exception table entries built during codegen; `label_locals_override` preserves try-start locals at merge points. Multi-catch uses `java/lang/Throwable` as the stack map type since the LCA cannot be computed without class hierarchy knowledge. +- **StackMapTable**: `CompileOptions::default()` generates StackMapTable frames (all tests run with full JVM verification). `label_locals_override` is used across all control flow structures (if/else, while, for, for-each, switch, switch-expr, try-catch, synchronized) to capture pre-branch locals for merge-point frames. `label_stack_override` handles expression-level merge points (comparisons, ternaries, logical ops, switch expressions) where values are on the stack. `FrameTracker::record_frame()` keeps the last frame when multiple labels share the same bytecode offset. 
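The density heuristic mentioned in the **Switch** bullet above could look roughly like the sketch below. This is not the code in `codegen.rs`: `SwitchKind`, `choose_switch`, and the `min_density` threshold are assumptions made for illustration, and the actual cut-off the compiler uses is not stated in this document.

```rust
/// Which switch instruction to emit (hypothetical enum for this sketch).
#[derive(Debug, PartialEq, Eq)]
enum SwitchKind {
    /// `tableswitch`: one jump slot per value in `lo..=hi`.
    Table,
    /// `lookupswitch`: one sorted (key, offset) pair per case.
    Lookup,
}

/// Choose by density, defined here as number of cases / size of the key range.
/// Assumptions: `case_keys` holds distinct keys, and `min_density` is a
/// caller-supplied threshold (for example 0.5).
fn choose_switch(case_keys: &[i32], min_density: f64) -> SwitchKind {
    let (lo, hi) = match (case_keys.iter().min(), case_keys.iter().max()) {
        (Some(&lo), Some(&hi)) => (lo, hi),
        // No cases at all: a lookupswitch with zero pairs is the smaller form.
        _ => return SwitchKind::Lookup,
    };
    // Widen to i64 so `hi - lo + 1` cannot overflow for extreme i32 keys.
    let range = i64::from(hi) - i64::from(lo) + 1;
    let density = case_keys.len() as f64 / range as f64;
    if density >= min_density {
        SwitchKind::Table
    } else {
        SwitchKind::Lookup
    }
}
```

With a threshold of 0.5, keys `1, 2, 3, 4` fill their range completely and select `tableswitch`, while keys `1, 1000` cover a range of 1000 values with only two cases and fall back to `lookupswitch`.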
+ +### What the compiler supports today + +**Statements**: local declarations (including `var`), expression statements, return (typed), if/else, while, for (traditional), for-each, switch (tableswitch + lookupswitch), try-catch-finally (multiple catches, multi-catch), throw, break, continue, blocks, synchronized + +**Expressions**: int/long/float/double/boolean/char/string/null literals, identifiers, `this`, all binary arithmetic (+, -, *, /, %), bitwise (&, |, ^, ~, <<, >>, >>>), comparisons (==, !=, <, <=, >, >=), logical (&&, || with short-circuit), ternary, switch expressions (arrow syntax), assignment (simple + compound), pre/post increment/decrement, method calls (with invokeinterface detection, generic type params), field access (including `array.length`), array access/creation (including multi-dimensional), object instantiation, casts, instanceof, lambda expressions, method references + +**Parameter access**: Method parameters are accessible by positional name (`arg0`, `arg1`, ...) always. When debug info is present (`javac -g` for `LocalVariableTable`, or `javac -parameters` for `MethodParameters`), original parameter names (e.g., `args`, `name`) are also available as aliases. Wide types (long/double) correctly consume 2 slots. Instance methods have `this` at slot 0; `arg0` is the first declared parameter. + +**Patching**: `compile_method_body()`, `patch_method!`, `patch_methods!`, `prepend_method_body()`, `prepend_method!`, `patch_jar_method!`, `patch_jar_class!`, `patch_jar!` + +**Prepend mode**: `prepend_method_body()` / `prepend_method!` inserts compiled code at the beginning of an existing method body, preserving the original instructions. Handles exception table offset adjustment, StackMapTable frame merging (delta re-encoding), and debug attribute stripping. Trailing returns are auto-stripped so prepended code falls through to the original body. Controlled by `InsertMode::Prepend` in `CompileOptions`. 
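To make the exception table offset adjustment in prepend mode concrete, here is a minimal sketch of the idea: every original offset moves forward by the byte length of the prepended code. `ExceptionEntry` and `shift_exception_table` are hypothetical names for this example, not the crate's types or functions. The four `u16` fields mirror the `exception_table` entry layout in the JVM specification, and the sketch does not cover the StackMapTable delta re-encoding that the real implementation also performs.

```rust
/// Hypothetical mirror of one `exception_table` entry in a Code attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ExceptionEntry {
    start_pc: u16,
    end_pc: u16,
    handler_pc: u16,
    catch_type: u16,
}

/// Shift every bytecode offset by `prepended_len` bytes.
/// Returns `None` if any shifted offset would no longer fit in a u16,
/// since the class file format cannot represent it.
fn shift_exception_table(
    entries: &[ExceptionEntry],
    prepended_len: u16,
) -> Option<Vec<ExceptionEntry>> {
    entries
        .iter()
        .map(|e| {
            Some(ExceptionEntry {
                start_pc: e.start_pc.checked_add(prepended_len)?,
                end_pc: e.end_pc.checked_add(prepended_len)?,
                handler_pc: e.handler_pc.checked_add(prepended_len)?,
                // A constant-pool index, not a code offset, so it is unchanged.
                catch_type: e.catch_type,
            })
        })
        .collect()
}
```

Returning `None` on overflow rather than wrapping is a deliberate choice in this sketch: a silently wrapped offset would point a handler at unrelated bytecode.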
+ +## JAR Explorer (`examples/jar_explorer.rs`) + +TUI application for browsing and editing Java `.jar` files. Run with: +```sh +cargo run --example jar_explorer --features tui-example -- path/to/file.jar +``` + +**Browsing**: Tree-based file navigation with expand/collapse. Views `.class` files (decompiled Java or bytecode listing), manifests, text files, nested JARs, hex dumps. Spring Boot format detection. Vim-like navigation (hjkl, gg/G, /search, n/N). + +**Editing**: Press `e` on a loaded `.class` file to enter edit mode. Select a method from the list, type Java source in the editor (`{ ... }` block), then `Ctrl+S` to compile & replace or `Ctrl+P` to compile & prepend. Errors are shown inline; fix and retry. Press `W` to save the modified JAR as `.patched.jar`. + +**Key bindings**: Tree: `hjkl` navigate, `Enter`/`l` open, `e` edit, `W` save, `Tab` switch to viewer. Viewer: vim movement, `/` search, `Tab` back to tree. Edit: `j/k` select method, `Enter` open editor, `Ctrl+S` replace, `Ctrl+P` prepend, `Esc` cancel. + +## Compiler Roadmap + +Prioritized list of missing features. Items marked [done] have been implemented. + +### P0 — High impact (blocks common patching patterns) + +1. [done] **String concatenation with `+`** — `StringBuilder` codegen: `new StringBuilder().append(a).append(b).toString()`. Flattens chained `+` into a single StringBuilder. Type-aware append descriptors via `infer_expr_type`. + +2. [done] **Long/float/double arithmetic** — Type-dispatched binary ops (`ladd`/`fadd`/`dadd` etc.), widening conversions (`i2l`, `i2f`, `i2d`, `l2f`, `l2d`, `f2d`), typed comparisons (`lcmp`, `fcmpl`/`fcmpg`, `dcmpl`/`dcmpg`), typed casts, typed unary ops, typed compound assign and increment/decrement. + +3. [done] **For-each loops** — `for (Type x : iterable)` with both array mode (arraylength + index counter + typed array load) and Iterable mode (invokeinterface iterator/hasNext/next + checkcast). + +4. 
[done] **Type-aware array load/store** — Correct instruction per element type: `iaload`/`iastore`, `laload`/`lastore`, `faload`/`fastore`, `daload`/`dastore`, `baload`/`bastore`, `caload`/`castore`, `saload`/`sastore`, `aaload`/`aastore`. Fixed array-store stack ordering in assignments. + +**Foundation**: `infer_expr_type` — expression type inference used by all P0 features for type-dispatched codegen, println/append descriptor selection, and widening decisions. + +### P1 — Medium impact (limits what you can patch) + +5. [done] **Multi-catch** — `catch (IOException | SQLException e)` with `|`-separated types in catch clause. Parser collects multiple types; codegen emits separate exception table entries per type, all pointing to the same handler. + +6. [done] **Field assignment on complex receivers** — Category-2 values (long/double) on field stores now use temp-local strategy instead of `Swap` (which only works for category-1). Added `descriptor_to_type` helper. + +7. [done] **Method resolution improvement** — `infer_receiver_class` now resolves method call return types from the constant pool. Instance calls check for `InterfaceMethodRef` in pool and known JDK interfaces to choose `invokeinterface` vs `invokevirtual`. Static field detection also checks `FieldRef` entries in pool before the uppercase heuristic fallback. + +8. [done] **Synchronized blocks** — `synchronized (expr) { ... }` with `monitorenter`/`monitorexit` codegen. Implicit catch-all handler ensures `monitorexit` on exceptional exits (same pattern as try-finally). + +### P2 — Lower priority (nice to have) + +9. [done] **Lambda expressions / method references** — `invokedynamic` + `LambdaMetafactory` bootstrap method generation. Compiles lambda body into synthetic private static method, sets up bootstrap methods attribute, emits `invokedynamic`. Supports typed and inferred params, expression and block bodies. Method references via `Class::method` syntax. + +10. 
[done] **`var` keyword** (Java 10+) — Parser emits sentinel type `TypeName::Class("__var__")`, codegen resolves via `infer_expr_type`. Requires initializer. + +11. [done] **Switch expressions** (Java 14+) — Arrow syntax `case 1 -> expr;` with `SwitchExpr` AST variant. Reuses tableswitch/lookupswitch infrastructure. Each case pushes value and jumps to end. Requires default case. + +12. [done] **Multi-dimensional array creation** — `new int[3][4]` using `Multianewarray` instruction. `NewMultiArray` AST variant with dimension list. Parser detects consecutive `[expr]` after first dimension. + +13. [done] **Generic type parameters in method calls** — `obj.method()` parse-and-discard. `skip_type_parameters()` handles nested `<>` including `>>` closing. Works in both pre-name and post-name positions in dotted postfix. + +### P3 — Code insertion + +14. [done] **Prepend mode** — `prepend_method_body()` / `prepend_method!` inserts compiled code before an existing method body. Handles exception table offset adjustment, StackMapTable frame merging (absolute offset conversion + delta re-encoding), debug attribute stripping. Trailing returns auto-stripped for fall-through. Append mode (insert after) deferred — requires modifying existing return instructions. 
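
The offset bookkeeping behind prepend mode can be sketched independently of the crate. The example below tracks bytecode offsets only and assumes the prepended code is `shift` bytes long once trailing returns are stripped. All names (`ExceptionRow`, `shift_exception_table`, `merge_frame_offsets`, and so on) are hypothetical. Real StackMapTable frames also carry locals and stack types, and a frame's kind may have to change when its delta changes; neither is modeled here.

```rust
/// Exception table row, offsets only (catch type omitted for brevity).
#[derive(Debug, Clone, Copy, PartialEq)]
struct ExceptionRow {
    start_pc: u16,
    end_pc: u16,
    handler_pc: u16,
}

/// Move every original exception range past the prepended code.
fn shift_exception_table(rows: &[ExceptionRow], shift: u16) -> Vec<ExceptionRow> {
    rows.iter()
        .map(|r| ExceptionRow {
            start_pc: r.start_pc + shift,
            end_pc: r.end_pc + shift,
            handler_pc: r.handler_pc + shift,
        })
        .collect()
}

/// StackMapTable deltas to absolute offsets: the first frame sits at
/// `delta`, each later frame at `previous + delta + 1`.
fn deltas_to_absolute(deltas: &[u16]) -> Vec<u32> {
    let mut out = Vec::with_capacity(deltas.len());
    let mut prev: Option<u32> = None;
    for &d in deltas {
        let off = match prev {
            None => d as u32,
            Some(p) => p + d as u32 + 1,
        };
        out.push(off);
        prev = Some(off);
    }
    out
}

/// Inverse of `deltas_to_absolute`. Offsets must be strictly increasing
/// and small enough that each delta fits in a u16.
fn absolute_to_deltas(offsets: &[u32]) -> Vec<u16> {
    let mut prev: Option<u32> = None;
    offsets
        .iter()
        .map(|&off| {
            let d = match prev {
                None => off,
                Some(p) => off - p - 1,
            };
            prev = Some(off);
            d as u16
        })
        .collect()
}

/// Frames of the prepended code come first (already absolute), followed by
/// the original frames shifted by the prepended length, then re-encoded.
fn merge_frame_offsets(prepended_abs: &[u32], original_deltas: &[u16], shift: u32) -> Vec<u16> {
    let mut all: Vec<u32> = prepended_abs.to_vec();
    all.extend(deltas_to_absolute(original_deltas).iter().map(|o| o + shift));
    absolute_to_deltas(&all)
}

fn main() {
    let rows = [ExceptionRow { start_pc: 0, end_pc: 8, handler_pc: 8 }];
    println!("{:?}", shift_exception_table(&rows, 10));
    // Original frames at offsets 5 and 9, one prepended frame at offset 4.
    println!("{:?}", merge_frame_offsets(&[4], &[5, 3], 10));
}
```

Converting to absolute offsets before merging keeps the arithmetic simple: only the first original frame's delta actually changes, but computing it from absolute positions avoids special-casing the boundary between prepended and original frames.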
+ +## Current State + +### What's done +- All type structs: ClassFile, ConstantInfo (19 types including Module/Package), FieldInfo, MethodInfo, AttributeInfo, Instruction +- All 30 attribute variant types (including Module, ModulePackages, ModuleMainClass, NestHost, NestMembers, Record, PermittedSubclasses) +- Custom BinRead for ClassFile (handles const pool Double/Long sentinel entries) +- Full BinWrite support for all types (read-write round-trip verified) +- `sync_from_parsed()` / `sync_counts()` / `sync_lengths()` for patching and rewriting class files +- Recursive `interpret_inner` propagation to Code and Record sub-attributes +- Higher-level API helpers on ClassFile and MethodInfo +- Legacy nom parser files fully removed; builds on stable Rust +- Compile feature: lexer, parser, AST, codegen, StackMapTable generation, method patching macros +- Compile P0 complete: string concat, typed arithmetic (long/float/double), for-each loops, type-aware arrays, expression type inference +- Compile P1 complete: multi-catch, synchronized blocks, field assignment fix for cat-2 values, method resolution (invokeinterface, pool-based receiver inference, pool-based static detection, descriptor inference fallback for unknown methods) +- Compile P2 complete: `var` keyword (type inference), switch expressions (arrow syntax), multi-dimensional arrays (multianewarray), generic type params (parse-and-discard), lambda expressions (invokedynamic + synthetic methods + bootstrap methods), method references +- Constant pool helpers: `get_or_add_method_handle`, `get_or_add_method_type`, `get_or_add_invoke_dynamic` +- JAR patching: `patch_jar_method!`, `patch_jar_class!`, `patch_jar!` macros with E2E tests +- StackMapTable: VType resolution for reference-type locals (including parameters) uses constant pool indices; for-each loop variables allocated after loop-top label for correct frame generation; category-2 types (Long/Double) correctly omit implicit Top continuation slots in frame 
encoding +- Method descriptor inference: `find_method_descriptor_in_pool()` falls back to `infer_method_descriptor()` for methods not in the pool, with well-known collection method signatures and heuristic return types +- `infer_expr_type()` for MethodCall/StaticMethodCall now resolves return types from method descriptors when available +- Method references resolve descriptors from constant pool, with well-known fallbacks; functional interface and SAM descriptor derived from resolved types +- Robust attribute parsing: `interpret_inner` uses bounds-checked const pool access and graceful error handling (malformed attributes fall back to raw bytes instead of panicking) +- `sync_counts()` uses checked u16 conversion to prevent silent overflow +- 290+ tests passing across all test files +- Compiler tests split into submodules: `cargo test --test compiler parser::` / `e2e::` / `stress::` / `param_access::` / `prepend::` +- JAR Explorer TUI (`examples/jar_explorer.rs`): interactive browsing + code editing with compile & prepend support, save modified JARs diff --git a/Cargo.toml b/Cargo.toml index a91e7b8..78561ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,11 +11,38 @@ readme = "README.md" license = "MIT" exclude = ["java-assets/out/**/*"] +[package.metadata.docs.rs] +all-features = true + +[features] +decompile = [] +compile = ["decompile"] +jar-utils = ["dep:zip"] +jar-patch = ["compile", "jar-utils"] +spring-utils = ["jar-utils"] +tui-example = ["jar-utils", "spring-utils", "compile", "dep:ratatui", "dep:crossterm", "dep:tui-textarea"] + [dependencies] -nom = "^7" bitflags = "^2.3" cesu8 = "^1.1" binrw = "0.15.0" +zip = { version = "8.1", optional = true } +ratatui = { version = "0.29", features = ["crossterm"], optional = true } +crossterm = { version = "0.28", optional = true } +tui-textarea = { version = "0.7", features = ["crossterm", "search"], optional = true } [dev-dependencies] assert_matches = "1.5.0" +tempfile = "3" + +[[example]] +name = "jar_explorer" 
+required-features = ["tui-example"] + +[[example]] +name = "compile_patch" +required-features = ["compile"] + +[[example]] +name = "jar_patch" +required-features = ["jar-patch"] diff --git a/README.md b/README.md index 088beaf..a389f20 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ ![Rust](https://github.com/Palmr/classfile-parser/workflows/Rust/badge.svg) [![Crates.io Version](https://img.shields.io/crates/v/classfile-parser.svg)](https://crates.io/crates/classfile-parser) -A parser for [Java Classfiles](https://docs.oracle.com/javase/specs/jvms/se10/html/jvms-4.html), written in Rust using [nom](https://github.com/Geal/nom). +A parser for [Java Classfiles](https://docs.oracle.com/javase/specs/jvms/se10/html/jvms-4.html), written in Rust using [binrw](https://github.com/jam1garner/binrw). + +Supports reading, modifying, and writing class files with full round-trip fidelity. Optional features provide JAR archive handling and Spring Boot fat JAR support. ## Installation @@ -15,42 +17,293 @@ Classfile Parser is available from crates.io and can be included in your Cargo e classfile-parser = "~0.3" ``` +### Optional features + +```toml +# Bytecode-to-Java source decompiler +classfile-parser = { version = "~0.3", features = ["decompile"] } + +# JAR archive reading/writing +classfile-parser = { version = "~0.3", features = ["jar-utils"] } + +# Spring Boot fat JAR support (includes jar-utils) +classfile-parser = { version = "~0.3", features = ["spring-utils"] } + +# Java-to-bytecode compiler for patching method bodies (includes decompile) +classfile-parser = { version = "~0.3", features = ["compile"] } + +# Patch methods inside JAR files (includes compile + jar-utils) +classfile-parser = { version = "~0.3", features = ["jar-patch"] } +``` + ## Usage +### Parsing a class file + ```rust -extern crate classfile_parser; +use classfile_parser::ClassFile; +use binrw::BinRead; +use std::io::Cursor; + +fn main() { + let classfile_bytes = 
include_bytes!("../path/to/JavaClass.class"); + let class_file = ClassFile::read(&mut Cursor::new(classfile_bytes)) + .expect("Failed to parse class file"); + + println!( + "version {},{} const_pool({}), this=const[{}], super=const[{}], \ + interfaces({}), fields({}), methods({}), attributes({}), access({:?})", + class_file.major_version, + class_file.minor_version, + class_file.const_pool_size, + class_file.this_class, + class_file.super_class, + class_file.interfaces_count, + class_file.fields_count, + class_file.methods_count, + class_file.attributes_count, + class_file.access_flags + ); -use classfile_parser::class_parser; + // Look up names via the constant pool + if let Some(name) = class_file.get_utf8(class_file.this_class) { + println!("Class name: {name}"); + } +} +``` + +### Modifying and writing a class file + +```rust +use classfile_parser::ClassFile; +use classfile_parser::code_attribute::Instruction; +use binrw::{BinRead, BinWrite}; +use std::io::Cursor; fn main() { let classfile_bytes = include_bytes!("../path/to/JavaClass.class"); + let mut class_file = ClassFile::read(&mut Cursor::new(classfile_bytes)) + .expect("Failed to parse class file"); + + // Find a method and modify its bytecode + if let Some(method) = class_file.find_method_mut("main") { + method.with_code(|code| { + // Replace the first instruction with a nop + code.replace_instruction(0, Instruction::Nop); + }); + } + + // Sync counts and write back out + class_file.sync_all().expect("sync failed"); + let mut output = Cursor::new(Vec::new()); + class_file.write(&mut output).expect("write failed"); +} +``` + +### Working with JAR files + +Requires the `jar-utils` feature. An example script for this feature can be run with `cargo run --example jar_explorer --features tui-example -- path/to/your/jar/file.jar`. 
- match class_parser(classfile_bytes) { - Ok((_, class_file)) => { - println!( - "version {},{} \ - const_pool({}), \ - this=const[{}], \ - super=const[{}], \ - interfaces({}), \ - fields({}), \ - methods({}), \ - attributes({}), \ - access({:?})", - class_file.major_version, - class_file.minor_version, - class_file.const_pool_size, - class_file.this_class, - class_file.super_class, - class_file.interfaces_count, - class_file.fields_count, - class_file.methods_count, - class_file.attributes_count, - class_file.access_flags - ); +```rust +use classfile_parser::jar_utils::JarFile; + +fn main() { + let jar = JarFile::open("path/to/file.jar").expect("Failed to open JAR"); + + for name in jar.entry_names() { + println!("{name}"); + } + + // Parse a class directly from the JAR + let class_file = jar.parse_class("com/example/Main.class") + .expect("Failed to parse class"); + + // Read and modify the manifest + if let Ok(Some(manifest)) = jar.manifest() { + if let Some(main_class) = manifest.main_attr("Main-Class") { + println!("Main-Class: {main_class}"); } - Err(_) => panic!("Failed to parse"), + } +} +``` + +### Decompiling class files + +Requires the `decompile` feature. Converts parsed bytecode back to readable Java source. 
+ +```rust +use classfile_parser::ClassFile; +use classfile_parser::decompile::{self, Decompiler, DecompileOptions, RenderConfig}; +use binrw::BinRead; +use std::io::Cursor; + +fn main() { + let bytes = include_bytes!("../path/to/JavaClass.class"); + let class = ClassFile::read(&mut Cursor::new(bytes)) + .expect("Failed to parse class file"); + + // Quick one-liner with default options + let source = decompile::decompile(&class).expect("decompilation failed"); + println!("{source}"); + + // Or configure the decompiler + let options = DecompileOptions { + render_config: RenderConfig { + indent: " ".into(), + max_line_width: 100, + use_var: false, + include_synthetic: false, + }, + include_synthetic: false, + ..Default::default() }; + let decompiler = Decompiler::new(options); + let source = decompiler.decompile(&class).expect("decompilation failed"); + println!("{source}"); + + // Decompile a single method + let method_src = decompiler.decompile_method(&class, "main") + .expect("method decompilation failed"); + println!("{method_src}"); +} +``` + +### Compiling and patching method bodies + +Requires the `compile` feature. Replaces method bodies in an existing class file with new Java source compiled directly to bytecode — no `javac` needed at runtime. + +The `patch_method!` macro is the easiest way to patch a method. 
It compiles the given Java method body, generates a valid StackMapTable, and replaces the named method in one call: + +```rust +use classfile_parser::{patch_method, ClassFile}; + +fn main() { + let bytes = std::fs::read("HelloWorld.class").unwrap(); + let mut class_file = ClassFile::from_bytes(&bytes).unwrap(); + + patch_method!(class_file, "main", r#"{ + System.out.println("patched!"); + }"#).unwrap(); + + std::fs::write("HelloWorld.class", class_file.to_bytes().unwrap()).unwrap(); +} +``` + +Use `patch_methods!` to patch several methods at once: + +```rust +use classfile_parser::{patch_methods, ClassFile}; + +fn main() { + let bytes = std::fs::read("MyClass.class").unwrap(); + let mut cf = ClassFile::from_bytes(&bytes).unwrap(); + + patch_methods!(cf, { + "main" => r#"{ + int x = 42; + if (x > 10) { + System.out.println("big"); + } else { + System.out.println("small"); + } + }"#, + "helper" => r#"{ return 99; }"#, + }).unwrap(); + + std::fs::write("MyClass.class", cf.to_bytes().unwrap()).unwrap(); +} +``` + +Both macros generate a StackMapTable by default, so patched classes pass full JVM bytecode verification. Pass `no_verify` to skip generation if you'll run with `-noverify`: + +```rust +patch_method!(class_file, "main", r#"{ return; }"#, no_verify).unwrap(); +``` + +The compiler supports: local variables, arithmetic, if/else, while, for, break/continue, switch (tableswitch/lookupswitch), try-catch-finally, return, throw, method calls, field access, object creation, arrays, casts, instanceof, and ternary expressions. + +A full working example is at [`examples/compile_patch.rs`](examples/compile_patch.rs): + +```sh +cargo run --example compile_patch --features compile +``` + +### Patching methods inside JAR files + +Requires the `jar-patch` feature (which enables both `compile` and `jar-utils`). Patch method bodies directly inside a JAR without extracting or recompiling — no `javac` needed. 
+ +Patch a single method with `patch_jar_method!`: + +```rust +use classfile_parser::jar_utils::JarFile; +use classfile_parser::patch_jar_method; + +fn main() { + let mut jar = JarFile::open("app.jar").unwrap(); + + patch_jar_method!(jar, "com/example/Main.class", "main", r#"{ + System.out.println("patched!"); + }"#).unwrap(); + + jar.save("app-patched.jar").unwrap(); +} +``` + +Use `patch_jar!` to batch patches across multiple classes — each class is parsed once, all its methods are patched, and the class is written back once: + +```rust +use classfile_parser::jar_utils::JarFile; +use classfile_parser::patch_jar; + +fn main() { + let mut jar = JarFile::open("app.jar").unwrap(); + + patch_jar!(jar, { + "com/example/Main.class" => { + "main" => r#"{ System.out.println("patched main"); }"#, + "helper" => r#"{ return 42; }"#, + }, + "com/example/Util.class" => { + "compute" => r#"{ return 0; }"#, + }, + }).unwrap(); + + jar.save("app-patched.jar").unwrap(); +} +``` + +All JAR patching macros generate a StackMapTable by default. Pass `no_verify` to skip: + +```rust +patch_jar_method!(jar, "Main.class", "main", r#"{ return; }"#, no_verify).unwrap(); +``` + +A full working example is at [`examples/jar_patch.rs`](examples/jar_patch.rs): + +```sh +cargo run --example jar_patch --features jar-patch +``` + +### Spring Boot fat JARs + +Requires the `spring-utils` feature. 
+ +```rust +use classfile_parser::spring_utils::SpringBootJar; + +fn main() { + if let Ok(Some(sb)) = SpringBootJar::open("path/to/app.jar") { + println!("Format: {:?}", sb.format()); + println!("Start-Class: {:?}", sb.start_class()); + + for name in sb.app_class_names() { + println!(" {name}"); + } + + for name in sb.nested_jar_names() { + println!(" lib: {name}"); + } + } } ``` @@ -76,6 +329,8 @@ fn main() { - [x] MethodHandle - [x] MethodType - [x] InvokeDynamic + - [x] Module + - [x] Package - [x] Access flags - [x] This class - [x] Super class @@ -103,7 +358,7 @@ fn main() { - [x] RuntimeVisibleTypeAnnotations - [x] RuntimeInvisibleTypeAnnotations - [x] AnnotationDefault - - [X] MethodParameters + - [x] MethodParameters - [x] Useful but not critical - [x] SourceFile - [~] SourceDebugExtension @@ -111,3 +366,55 @@ fn main() { - [x] LocalVariableTable - [x] LocalVariableTypeTable - [x] Deprecated + - [x] Java 9+ module system + - [x] Module + - [x] ModulePackages + - [x] ModuleMainClass + - [x] Java 11+ nesting + - [x] NestHost + - [x] NestMembers + - [x] Java 16+ records and sealed classes + - [x] Record + - [x] PermittedSubclasses +- [x] Instructions + - [x] All 200+ JVM opcodes + - [x] Wide instruction variants + - [x] tableswitch / lookupswitch with alignment padding +- [x] Read-write round-trip support (BinRead + BinWrite) +- [x] Patching support + - [x] `sync_from_parsed()` for attribute reserialization + - [x] `sync_counts()` / `sync_lengths()` for count field recalculation + - [x] `sync_all()` for full class file resync + - [x] Instruction replacement and nop-out helpers + - [x] Constant pool addition helpers (`add_utf8`, `add_string`, `add_class`, etc.) 
+- [x] JAR utilities (optional `jar-utils` feature) + - [x] Read/write JAR archives + - [x] Parse class files directly from JARs + - [x] Manifest parsing and serialization +- [x] Spring Boot support (optional `spring-utils` feature) + - [x] Detect JAR/WAR format + - [x] Application class and resource enumeration + - [x] Nested JAR access + - [x] classpath.idx and layers.idx parsing +- [x] Decompiler (optional `decompile` feature) + - [x] Control flow graph construction from bytecode + - [x] Stack simulation and expression tree recovery + - [x] Control flow structuring (if/else, while, for, switch, try-catch) + - [x] Type inference, generics, and annotation recovery + - [x] Java source rendering with import management + - [x] Record, sealed class, and enum support + - [x] Per-method error recovery with bytecode fallback + - [x] Inner class decompilation + - [x] Compiler desugaring (autoboxing, for-each, assert) +- [x] Compiler (optional `compile` feature) + - [x] Java method body lexer, parser, and AST + - [x] Bytecode codegen (locals, arithmetic, comparisons, logical ops) + - [x] Control flow: if/else, while, for, break/continue, switch (tableswitch/lookupswitch) + - [x] Exception handling: try-catch-finally with exception table generation + - [x] Object creation, method calls, field access, arrays, casts, instanceof, ternary + - [x] StackMapTable generation for full JVM bytecode verification + - [x] Method body patching (`compile_method_body`, `patch_method!`, `patch_methods!`) +- [x] JAR patching (optional `jar-patch` feature) + - [x] Patch methods directly inside JAR archives + - [x] Batch by class — parse once, patch N methods, write once + - [x] `patch_jar_method!`, `patch_jar_class!`, `patch_jar!` macros diff --git a/examples/compile_patch.rs b/examples/compile_patch.rs new file mode 100644 index 0000000..63b809f --- /dev/null +++ b/examples/compile_patch.rs @@ -0,0 +1,123 @@ +//! Example: patch a compiled Java class using the compile feature. +//! +//! 
This example: +//! 1. Compiles a minimal Java class with `javac` +//! 2. Parses the resulting `.class` file with `ClassFile::from_bytes` +//! 3. Replaces method bodies using `patch_method!` / `patch_methods!` +//! 4. Writes the modified class back with `ClassFile::to_bytes` +//! 5. Runs it with `java` to show the new behavior +//! +//! Run with: +//! cargo run --example compile_patch --features compile + +use std::fs; +use std::process::Command; + +use classfile_parser::{ClassFile, patch_method, patch_methods}; + +fn main() { + // ── Step 1: Create and compile a Java class with two methods ───────── + let tmp_dir = std::env::temp_dir().join("classfile_compile_example"); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let java_src = tmp_dir.join("HelloWorld.java"); + fs::write( + &java_src, + r#" +public class HelloWorld { + public static void greet() { + System.out.println("original greet"); + } + + public static void main(String[] args) { + System.out.println("original main"); + greet(); + } +} +"#, + ) + .unwrap(); + + let javac = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg(&java_src) + .output() + .expect("javac not found — make sure a JDK is on your PATH"); + assert!( + javac.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&javac.stderr) + ); + println!("Compiled HelloWorld.java"); + + // ── Step 2: Parse the .class file ──────────────────────────────────── + let class_path = tmp_dir.join("HelloWorld.class"); + let bytes = fs::read(&class_path).unwrap(); + let mut class_file = ClassFile::from_bytes(&bytes).expect("failed to parse class"); + println!("Parsed HelloWorld.class ({} bytes)", bytes.len()); + + // ── Step 3a: Patch a single method ─────────────────────────────────── + // + // patch_method! compiles a Java method body and replaces the named method. + // By default it generates a StackMapTable so the class passes full JVM + // bytecode verification. 
Use `no_verify` as a 4th argument to skip that. + patch_method!( + class_file, + "greet", + r#"{ + System.out.println("patched greet!"); + }"# + ) + .unwrap(); + println!("Patched greet()"); + + // ── Step 3b: Patch multiple methods at once ────────────────────────── + // + // patch_methods! patches several methods in one call. Methods are compiled + // in order; an error on any method stops immediately. + patch_methods!(class_file, { + "main" => r#"{ + int x = 42; + switch (x) { + case 1: System.out.println("one"); break; + case 42: System.out.println("forty-two"); break; + default: System.out.println("other"); break; + } + + try { + throw new RuntimeException("boom"); + } catch (RuntimeException e) { + System.out.println("caught exception"); + } + + greet(); + }"#, + }) + .unwrap(); + println!("Patched main()"); + + // ── Step 4: Write back and run ─────────────────────────────────────── + let patched_bytes = class_file.to_bytes().expect("failed to serialize"); + fs::write(&class_path, &patched_bytes).unwrap(); + println!("Wrote patched class ({} bytes)\n", patched_bytes.len()); + + let run = Command::new("java") + .arg("-cp") + .arg(&tmp_dir) + .arg("HelloWorld") + .output() + .expect("java not found"); + if run.status.success() { + println!("{}", String::from_utf8_lossy(&run.stdout).trim()); + } else { + eprintln!( + "java failed: {}", + String::from_utf8_lossy(&run.stderr).trim() + ); + std::process::exit(1); + } + + let _ = fs::remove_dir_all(&tmp_dir); +} diff --git a/examples/jar_explorer.rs b/examples/jar_explorer.rs new file mode 100644 index 0000000..6b3c7bf --- /dev/null +++ b/examples/jar_explorer.rs @@ -0,0 +1,1940 @@ +//! TUI JAR Explorer — interactive browser for Java `.jar` files. +//! +//! ```sh +//! cargo run --example jar_explorer --features tui-example -- path/to/file.jar +//! 
``` + +use std::collections::BTreeMap; +use std::io::{self, Cursor}; + +use binrw::BinRead; +use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers}; +use crossterm::terminal::{self, EnterAlternateScreen, LeaveAlternateScreen}; +use ratatui::Terminal; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::{Constraint, Direction, Layout, Rect}; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Block, Borders, List, ListItem, Paragraph}; +use tui_textarea::{CursorMove, TextArea}; + +use classfile_parser::attribute_info::{ + AttributeInfoVariant, CodeAttribute, ExceptionEntry, LineNumberTableAttribute, +}; +use classfile_parser::code_attribute::Instruction; +use classfile_parser::compile::{CompileOptions, compile_method_body, prepend_method_body}; +use classfile_parser::constant_info::ConstantInfo; +use classfile_parser::field_info::FieldAccessFlags; +use classfile_parser::jar_utils::{JarFile, JarManifest}; +use classfile_parser::method_info::MethodAccessFlags; +use classfile_parser::spring_utils::{SpringBootFormat, detect_format}; +use classfile_parser::{ClassAccessFlags, ClassFile}; + +// --------------------------------------------------------------------------- +// Data structures +// --------------------------------------------------------------------------- + +struct TreeNode { + label: String, + entry_path: Option<String>, + depth: usize, + expanded: bool, + is_dir: bool, +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum Focus { + Tree, + Viewer, +} + +#[derive(Clone, PartialEq, Eq)] +enum VimMode { + Normal, + Search, + Pending(char), +} + +enum EditState<'a> { + SelectMethod { + entry_path: String, + methods: Vec<(String, String)>, // (name, formatted_signature) + selected: usize, + }, + EditCode { + entry_path: String, + method_name: String, + editor: TextArea<'a>, + error_message: Option<String>, + }, +} + +struct App<'a> { + jar: JarFile, + jar_path: String, + spring_format: Option<SpringBootFormat>, + tree: Vec<TreeNode>, 
tree_selected: usize, + tree_scroll: usize, + viewer: TextArea<'a>, + vim_mode: VimMode, + search_buffer: String, + focus: Focus, + viewer_title: String, + loaded_entry: Option<String>, + status_message: String, + should_quit: bool, + edit_state: Option<EditState<'a>>, + has_unsaved_changes: bool, +} + +// --------------------------------------------------------------------------- +// Tree building +// --------------------------------------------------------------------------- + +fn build_tree(jar: &JarFile) -> Vec<TreeNode> { + let mut nodes: Vec<TreeNode> = Vec::new(); + let mut dir_indices: BTreeMap<String, usize> = BTreeMap::new(); + + for name in jar.entry_names() { + let parts: Vec<&str> = name.split('/').collect(); + + // Ensure all ancestor directories exist + let mut accumulated = String::new(); + for (i, &part) in parts.iter().enumerate() { + if i < parts.len() - 1 { + // directory component + if !accumulated.is_empty() { + accumulated.push('/'); + } + accumulated.push_str(part); + let dir_key = accumulated.clone(); + if !dir_indices.contains_key(&dir_key) { + let idx = nodes.len(); + nodes.push(TreeNode { + label: format!("{}/", part), + entry_path: None, + depth: i, + expanded: true, + is_dir: true, + }); + dir_indices.insert(dir_key, idx); + } + } + } + + // Leaf file entry + let depth = parts.len() - 1; + nodes.push(TreeNode { + label: parts.last().unwrap_or(&name).to_string(), + entry_path: Some(name.to_string()), + depth, + expanded: false, + is_dir: false, + }); + } + + nodes +} + +fn visible_indices(tree: &[TreeNode]) -> Vec<usize> { + let mut visible = Vec::new(); + let mut skip_depth: Option<usize> = None; + + for (i, node) in tree.iter().enumerate() { + if let Some(sd) = skip_depth { + if node.depth > sd { + continue; + } else { + skip_depth = None; + } + } + visible.push(i); + if node.is_dir && !node.expanded { + skip_depth = Some(node.depth); + } + } + visible +} + +// --------------------------------------------------------------------------- +// Content formatters +// 
--------------------------------------------------------------------------- + +fn get_utf8(const_pool: &[ConstantInfo], index: u16) -> String { + if index == 0 { + return "".to_string(); + } + match const_pool.get((index - 1) as usize) { + Some(ConstantInfo::Utf8(u)) => u.utf8_string.clone(), + _ => format!("#{index}"), + } +} + +fn get_class_name(const_pool: &[ConstantInfo], index: u16) -> String { + if index == 0 { + return "".to_string(); + } + match const_pool.get((index - 1) as usize) { + Some(ConstantInfo::Class(c)) => get_utf8(const_pool, c.name_index), + _ => format!("#{index}"), + } +} + +fn get_name_and_type(const_pool: &[ConstantInfo], index: u16) -> (String, String) { + match const_pool.get((index - 1) as usize) { + Some(ConstantInfo::NameAndType(nat)) => ( + get_utf8(const_pool, nat.name_index), + get_utf8(const_pool, nat.descriptor_index), + ), + _ => (format!("#{index}"), String::new()), + } +} + +fn resolve_ref(const_pool: &[ConstantInfo], class_idx: u16, nat_idx: u16) -> String { + let class = get_class_name(const_pool, class_idx); + let (name, desc) = get_name_and_type(const_pool, nat_idx); + format!("{class}.{name}:{desc}") +} + +fn format_method_access(flags: MethodAccessFlags) -> String { + let mut parts = Vec::new(); + if flags.contains(MethodAccessFlags::PUBLIC) { + parts.push("public"); + } + if flags.contains(MethodAccessFlags::PRIVATE) { + parts.push("private"); + } + if flags.contains(MethodAccessFlags::PROTECTED) { + parts.push("protected"); + } + if flags.contains(MethodAccessFlags::STATIC) { + parts.push("static"); + } + if flags.contains(MethodAccessFlags::FINAL) { + parts.push("final"); + } + if flags.contains(MethodAccessFlags::SYNCHRONIZED) { + parts.push("synchronized"); + } + if flags.contains(MethodAccessFlags::NATIVE) { + parts.push("native"); + } + if flags.contains(MethodAccessFlags::ABSTRACT) { + parts.push("abstract"); + } + parts.join(" ") +} + +fn format_field_access(flags: FieldAccessFlags) -> String { + let mut parts = 
Vec::new(); + if flags.contains(FieldAccessFlags::PUBLIC) { + parts.push("public"); + } + if flags.contains(FieldAccessFlags::PRIVATE) { + parts.push("private"); + } + if flags.contains(FieldAccessFlags::PROTECTED) { + parts.push("protected"); + } + if flags.contains(FieldAccessFlags::STATIC) { + parts.push("static"); + } + if flags.contains(FieldAccessFlags::FINAL) { + parts.push("final"); + } + if flags.contains(FieldAccessFlags::VOLATILE) { + parts.push("volatile"); + } + if flags.contains(FieldAccessFlags::TRANSIENT) { + parts.push("transient"); + } + parts.join(" ") +} + +fn format_class_access(flags: ClassAccessFlags) -> String { + let mut parts = Vec::new(); + if flags.contains(ClassAccessFlags::PUBLIC) { + parts.push("public"); + } + if flags.contains(ClassAccessFlags::FINAL) { + parts.push("final"); + } + if flags.contains(ClassAccessFlags::ABSTRACT) { + parts.push("abstract"); + } + if flags.contains(ClassAccessFlags::INTERFACE) { + parts.push("interface"); + } + if flags.contains(ClassAccessFlags::ENUM) { + parts.push("enum"); + } + if flags.contains(ClassAccessFlags::ANNOTATION) { + parts.push("annotation"); + } + if flags.contains(ClassAccessFlags::MODULE) { + parts.push("module"); + } + if flags.contains(ClassAccessFlags::SYNTHETIC) { + parts.push("synthetic"); + } + parts.join(" ") +} + +fn descriptor_to_readable(desc: &str) -> String { + // Simple best-effort conversion of JVM type descriptors to readable form. 
+ let mut out = String::new(); + let mut chars = desc.chars().peekable(); + while let Some(c) = chars.next() { + match c { + 'B' => out.push_str("byte"), + 'C' => out.push_str("char"), + 'D' => out.push_str("double"), + 'F' => out.push_str("float"), + 'I' => out.push_str("int"), + 'J' => out.push_str("long"), + 'S' => out.push_str("short"), + 'Z' => out.push_str("boolean"), + 'V' => out.push_str("void"), + '[' => { + let inner = descriptor_to_readable(&chars.collect::<String>()); + return format!("{out}{inner}[]"); + } + 'L' => { + let class_name: String = chars.by_ref().take_while(|&ch| ch != ';').collect(); + out.push_str(&class_name.replace('/', ".")); + } + '(' => out.push('('), + ')' => out.push(')'), + _ => { + out.push(c); + } + } + } + out +} + +fn format_instruction(instr: &Instruction, const_pool: &[ConstantInfo]) -> String { + match instr { + // Invoke instructions — resolve to symbolic names + Instruction::Invokevirtual(idx) => { + if let Some(ConstantInfo::MethodRef(mr)) = const_pool.get((*idx - 1) as usize) { + format!( + "invokevirtual {}", + resolve_ref(const_pool, mr.class_index, mr.name_and_type_index) + ) + } else { + format!("invokevirtual #{idx}") + } + } + Instruction::Invokespecial(idx) => { + if let Some(ConstantInfo::MethodRef(mr)) = const_pool.get((*idx - 1) as usize) { + format!( + "invokespecial {}", + resolve_ref(const_pool, mr.class_index, mr.name_and_type_index) + ) + } else { + format!("invokespecial #{idx}") + } + } + Instruction::Invokestatic(idx) => { + if let Some(ConstantInfo::MethodRef(mr)) = const_pool.get((*idx - 1) as usize) { + format!( + "invokestatic {}", + resolve_ref(const_pool, mr.class_index, mr.name_and_type_index) + ) + } else { + format!("invokestatic #{idx}") + } + } + Instruction::Invokeinterface { index, count, .. 
} => { + if let Some(ConstantInfo::InterfaceMethodRef(mr)) = + const_pool.get((*index - 1) as usize) + { + format!( + "invokeinterface {} count={count}", + resolve_ref(const_pool, mr.class_index, mr.name_and_type_index) + ) + } else { + format!("invokeinterface #{index} count={count}") + } + } + Instruction::Invokedynamic { index, .. } => { + if let Some(ConstantInfo::InvokeDynamic(id)) = const_pool.get((*index - 1) as usize) { + let (name, desc) = get_name_and_type(const_pool, id.name_and_type_index); + format!( + "invokedynamic #{} {name}:{desc}", + id.bootstrap_method_attr_index + ) + } else { + format!("invokedynamic #{index}") + } + } + + // Field access + Instruction::Getfield(idx) + | Instruction::Getstatic(idx) + | Instruction::Putfield(idx) + | Instruction::Putstatic(idx) => { + let opname = match instr { + Instruction::Getfield(_) => "getfield", + Instruction::Getstatic(_) => "getstatic", + Instruction::Putfield(_) => "putfield", + Instruction::Putstatic(_) => "putstatic", + _ => unreachable!(), + }; + if let Some(ConstantInfo::FieldRef(fr)) = const_pool.get((*idx - 1) as usize) { + format!( + "{opname} {}", + resolve_ref(const_pool, fr.class_index, fr.name_and_type_index) + ) + } else { + format!("{opname} #{idx}") + } + } + + // Type operations + Instruction::New(idx) => format!("new {}", get_class_name(const_pool, *idx)), + Instruction::Checkcast(idx) => format!("checkcast {}", get_class_name(const_pool, *idx)), + Instruction::Instanceof(idx) => format!("instanceof {}", get_class_name(const_pool, *idx)), + Instruction::Anewarray(idx) => format!("anewarray {}", get_class_name(const_pool, *idx)), + + // LDC + Instruction::Ldc(idx) => format!("ldc {}", format_constant(const_pool, *idx as u16)), + Instruction::LdcW(idx) => format!("ldc_w {}", format_constant(const_pool, *idx)), + Instruction::Ldc2W(idx) => format!("ldc2_w {}", format_constant(const_pool, *idx)), + + // Fallback: Debug formatting + other => format!("{:?}", other).to_lowercase(), + } +} + 
+fn format_constant(const_pool: &[ConstantInfo], index: u16) -> String { + // Index 0 is never a valid constant; wrapping_sub turns it into usize::MAX so get() returns None instead of underflowing + match const_pool.get((index as usize).wrapping_sub(1)) { + Some(ConstantInfo::String(s)) => { + format!("\"{}\"", get_utf8(const_pool, s.string_index)) + } + Some(ConstantInfo::Integer(i)) => format!("{}", i.value), + Some(ConstantInfo::Float(f)) => format!("{}f", f.value), + Some(ConstantInfo::Long(l)) => format!("{}L", l.value), + Some(ConstantInfo::Double(d)) => format!("{}d", d.value), + Some(ConstantInfo::Class(c)) => format!("class {}", get_utf8(const_pool, c.name_index)), + Some(ConstantInfo::MethodType(mt)) => { + format!("methodtype {}", get_utf8(const_pool, mt.descriptor_index)) + } + _ => format!("#{index}"), + } +} + +fn format_class(jar: &JarFile, path: &str) -> Vec<String> { + let data = match jar.get_entry(path) { + Some(d) => d, + None => return vec![format!("Entry not found: {path}")], + }; + + let cf = match ClassFile::read(&mut Cursor::new(data)) { + Ok(c) => c, + Err(e) => return vec![format!("Failed to parse class: {e}")], + }; + + // Try decompilation first + match classfile_parser::decompile::decompile(&cf) { + Ok(source) => source.lines().map(|l| l.to_string()).collect(), + Err(e) => { + let mut lines = vec![ + format!("// Decompilation failed: {e}"), + "// Falling back to bytecode view".to_string(), + String::new(), + ]; + lines.extend(format_class_bytecode(&cf)); + lines + } + } +} + +fn format_class_bytecode(cf: &ClassFile) -> Vec<String> { + let cp = &cf.const_pool; + let mut lines = Vec::new(); + + let this_class = get_class_name(cp, cf.this_class); + let super_class = get_class_name(cp, cf.super_class); + let java_version = match cf.major_version { + 45 => "1.1", + 46 => "1.2", + 47 => "1.3", + 48 => "1.4", + 49 => "5", + 50 => "6", + 51 => "7", + 52 => "8", + 53 => "9", + 54 => "10", + 55 => "11", + 56 => "12", + 57 => "13", + 58 => "14", + 59 => "15", + 60 => "16", + 61 => "17", + 62 => "18", + 63 => "19", + 64 => "20", + 65 => "21", + 66 => "22", + 67 => "23", + 68 => "24", + _ => "?", 
+ }; + + lines.push(format!("=== Class: {} ===", this_class)); + lines.push(format!( + "Version: {}.{} (Java {java_version})", + cf.major_version, cf.minor_version + )); + lines.push(format!( + "Access: {}", + format_class_access(cf.access_flags) + )); + lines.push(format!("Super: {super_class}")); + + // Interfaces + if !cf.interfaces.is_empty() { + lines.push(String::new()); + lines.push(format!("--- Interfaces ({}) ---", cf.interfaces.len())); + for &iface in &cf.interfaces { + lines.push(format!(" {}", get_class_name(cp, iface))); + } + } + + // Fields + if !cf.fields.is_empty() { + lines.push(String::new()); + lines.push(format!("--- Fields ({}) ---", cf.fields.len())); + for field in &cf.fields { + let name = get_utf8(cp, field.name_index); + let desc = get_utf8(cp, field.descriptor_index); + let access = format_field_access(field.access_flags); + lines.push(format!( + " {access} {} {name}", + descriptor_to_readable(&desc) + )); + } + } + + // Methods + if !cf.methods.is_empty() { + lines.push(String::new()); + lines.push(format!("--- Methods ({}) ---", cf.methods.len())); + for method in &cf.methods { + let name = get_utf8(cp, method.name_index); + let desc = get_utf8(cp, method.descriptor_index); + let access = format_method_access(method.access_flags); + lines.push(format!( + " {access} {name}{}", + descriptor_to_readable(&desc) + )); + + if let Some(code) = method.code() { + format_code_body(&mut lines, code, cp); + } + } + } + + // Constant pool + lines.push(String::new()); + lines.push(format!("--- Constant Pool ({}) ---", cp.len())); + for (i, entry) in cp.iter().enumerate() { + let idx = i + 1; + let desc = match entry { + ConstantInfo::Utf8(u) => format!("Utf8 \"{}\"", u.utf8_string), + ConstantInfo::Integer(v) => format!("Integer {}", v.value), + ConstantInfo::Float(v) => format!("Float {}", v.value), + ConstantInfo::Long(v) => format!("Long {}", v.value), + ConstantInfo::Double(v) => format!("Double {}", v.value), + ConstantInfo::Class(c) => 
format!("Class #{}", c.name_index), + ConstantInfo::String(s) => format!("String #{}", s.string_index), + ConstantInfo::FieldRef(r) => { + format!("Fieldref #{}.#{}", r.class_index, r.name_and_type_index) + } + ConstantInfo::MethodRef(r) => { + format!("Methodref #{}.#{}", r.class_index, r.name_and_type_index) + } + ConstantInfo::InterfaceMethodRef(r) => format!( + "InterfaceMethodref #{}.#{}", + r.class_index, r.name_and_type_index + ), + ConstantInfo::NameAndType(n) => { + format!("NameAndType #{}.#{}", n.name_index, n.descriptor_index) + } + ConstantInfo::MethodHandle(h) => { + format!( + "MethodHandle kind={} #{}", + h.reference_kind, h.reference_index + ) + } + ConstantInfo::MethodType(t) => format!("MethodType #{}", t.descriptor_index), + ConstantInfo::InvokeDynamic(d) => format!( + "InvokeDynamic #{}:#{}", + d.bootstrap_method_attr_index, d.name_and_type_index + ), + ConstantInfo::Module(m) => format!("Module #{}", m.name_index), + ConstantInfo::Package(p) => format!("Package #{}", p.name_index), + ConstantInfo::Unusable => " (unusable)".to_string(), + }; + lines.push(format!(" #{idx:<4} {desc}")); + } + + // Class-level attributes + if !cf.attributes.is_empty() { + lines.push(String::new()); + lines.push(format!("--- Attributes ({}) ---", cf.attributes.len())); + for attr in &cf.attributes { + let attr_name = get_utf8(cp, attr.attribute_name_index); + lines.push(format!(" {attr_name} ({} bytes)", attr.attribute_length)); + } + } + + lines +} + +fn format_code_body(lines: &mut Vec<String>, code: &CodeAttribute, cp: &[ConstantInfo]) { + lines.push(format!( + " max_stack={}, max_locals={}, code_length={}", + code.max_stack, code.max_locals, code.code_length + )); + + // Find line number table if present + let line_table: Option<&LineNumberTableAttribute> = + code.attributes.iter().find_map(|a| match &a.info_parsed { + Some(AttributeInfoVariant::LineNumberTable(t)) => Some(t), + _ => None, + }); + + // Compute per-instruction byte addresses + let mut address = 0u32; + 
for instr in &code.code { + let line_info = line_table.and_then(|lt| { + lt.line_number_table + .iter() + .find(|e| e.start_pc as u32 == address) + .map(|e| e.line_number) + }); + let line_prefix = match line_info { + Some(ln) => format!("L{ln:<4}"), + None => " ".to_string(), + }; + lines.push(format!( + " {line_prefix} {address:04}: {}", + format_instruction(instr, cp) + )); + address += instruction_byte_size(instr, address); + } + + // Exception table + if !code.exception_table.is_empty() { + lines.push(format!( + " Exception table ({}):", + code.exception_table.len() + )); + for ExceptionEntry { + start_pc, + end_pc, + handler_pc, + catch_type, + } in &code.exception_table + { + let catch = if *catch_type == 0 { + "any".to_string() + } else { + get_class_name(cp, *catch_type) + }; + lines.push(format!( + " {start_pc}-{end_pc} -> {handler_pc} catch {catch}" + )); + } + } +} + +fn instruction_byte_size(instr: &Instruction, address: u32) -> u32 { + match instr { + // 1-byte instructions (no operands) + Instruction::Nop + | Instruction::Aconstnull + | Instruction::Aload0 + | Instruction::Aload1 + | Instruction::Aload2 + | Instruction::Aload3 + | Instruction::Astore0 + | Instruction::Astore1 + | Instruction::Astore2 + | Instruction::Astore3 + | Instruction::Aaload + | Instruction::Aastore + | Instruction::Areturn + | Instruction::Arraylength + | Instruction::Athrow + | Instruction::Baload + | Instruction::Bastore + | Instruction::Caload + | Instruction::Castore + | Instruction::D2f + | Instruction::D2i + | Instruction::D2l + | Instruction::Dadd + | Instruction::Daload + | Instruction::Dastore + | Instruction::Dcmpg + | Instruction::Dcmpl + | Instruction::Dconst0 + | Instruction::Dconst1 + | Instruction::Ddiv + | Instruction::Dload0 + | Instruction::Dload1 + | Instruction::Dload2 + | Instruction::Dload3 + | Instruction::Dmul + | Instruction::Dneg + | Instruction::Drem + | Instruction::Dreturn + | Instruction::Dstore0 + | Instruction::Dstore1 + | Instruction::Dstore2 
+ | Instruction::Dstore3 + | Instruction::Dsub + | Instruction::Dup + | Instruction::Dupx1 + | Instruction::Dupx2 + | Instruction::Dup2 + | Instruction::Dup2x1 + | Instruction::Dup2x2 + | Instruction::F2d + | Instruction::F2i + | Instruction::F2l + | Instruction::Fadd + | Instruction::Faload + | Instruction::Fastore + | Instruction::Fcmpg + | Instruction::Fcmpl + | Instruction::Fconst0 + | Instruction::Fconst1 + | Instruction::Fconst2 + | Instruction::Fdiv + | Instruction::Fload0 + | Instruction::Fload1 + | Instruction::Fload2 + | Instruction::Fload3 + | Instruction::Fmul + | Instruction::Fneg + | Instruction::Frem + | Instruction::Freturn + | Instruction::Fstore0 + | Instruction::Fstore1 + | Instruction::Fstore2 + | Instruction::Fstore3 + | Instruction::Fsub + | Instruction::I2b + | Instruction::I2c + | Instruction::I2d + | Instruction::I2f + | Instruction::I2l + | Instruction::I2s + | Instruction::Iadd + | Instruction::Iaload + | Instruction::Iand + | Instruction::Iastore + | Instruction::Iconstm1 + | Instruction::Iconst0 + | Instruction::Iconst1 + | Instruction::Iconst2 + | Instruction::Iconst3 + | Instruction::Iconst4 + | Instruction::Iconst5 + | Instruction::Idiv + | Instruction::Iload0 + | Instruction::Iload1 + | Instruction::Iload2 + | Instruction::Iload3 + | Instruction::Imul + | Instruction::Ineg + | Instruction::Ior + | Instruction::Irem + | Instruction::Ireturn + | Instruction::Ishl + | Instruction::Ishr + | Instruction::Istore0 + | Instruction::Istore1 + | Instruction::Istore2 + | Instruction::Istore3 + | Instruction::Isub + | Instruction::Iushr + | Instruction::Ixor + | Instruction::L2d + | Instruction::L2f + | Instruction::L2i + | Instruction::Ladd + | Instruction::Laload + | Instruction::Land + | Instruction::Lastore + | Instruction::Lcmp + | Instruction::Lconst0 + | Instruction::Lconst1 + | Instruction::Ldiv + | Instruction::Lload0 + | Instruction::Lload1 + | Instruction::Lload2 + | Instruction::Lload3 + | Instruction::Lmul + | Instruction::Lneg + | 
Instruction::Lor + | Instruction::Lrem + | Instruction::Lreturn + | Instruction::Lshl + | Instruction::Lshr + | Instruction::Lstore0 + | Instruction::Lstore1 + | Instruction::Lstore2 + | Instruction::Lstore3 + | Instruction::Lsub + | Instruction::Lushr + | Instruction::Lxor + | Instruction::Monitorenter + | Instruction::Monitorexit + | Instruction::Pop + | Instruction::Pop2 + | Instruction::Return + | Instruction::Saload + | Instruction::Sastore + | Instruction::Swap => 1, + + // 2-byte instructions (1 byte operand) + Instruction::Aload(_) + | Instruction::Astore(_) + | Instruction::Bipush(_) + | Instruction::Dload(_) + | Instruction::Dstore(_) + | Instruction::Fload(_) + | Instruction::Fstore(_) + | Instruction::Iload(_) + | Instruction::Istore(_) + | Instruction::Ldc(_) + | Instruction::Lload(_) + | Instruction::Lstore(_) + | Instruction::Newarray(_) + | Instruction::Ret(_) => 2, + + // 3-byte instructions (2 byte operand) + Instruction::Anewarray(_) + | Instruction::Checkcast(_) + | Instruction::Getfield(_) + | Instruction::Getstatic(_) + | Instruction::Goto(_) + | Instruction::IfAcmpeq(_) + | Instruction::IfAcmpne(_) + | Instruction::IfIcmpeq(_) + | Instruction::IfIcmpne(_) + | Instruction::IfIcmplt(_) + | Instruction::IfIcmpge(_) + | Instruction::IfIcmpgt(_) + | Instruction::IfIcmple(_) + | Instruction::Ifeq(_) + | Instruction::Ifne(_) + | Instruction::Iflt(_) + | Instruction::Ifge(_) + | Instruction::Ifgt(_) + | Instruction::Ifle(_) + | Instruction::Ifnonnull(_) + | Instruction::Ifnull(_) + | Instruction::Instanceof(_) + | Instruction::Invokespecial(_) + | Instruction::Invokestatic(_) + | Instruction::Invokevirtual(_) + | Instruction::Jsr(_) + | Instruction::LdcW(_) + | Instruction::Ldc2W(_) + | Instruction::New(_) + | Instruction::Putfield(_) + | Instruction::Putstatic(_) + | Instruction::Sipush(_) + | Instruction::Iinc { .. } => 3, + + // 4-byte instructions + Instruction::Multianewarray { .. 
} => 4, + + // 5-byte instructions + Instruction::GotoW(_) + | Instruction::JsrW(_) + | Instruction::Invokedynamic { .. } + | Instruction::Invokeinterface { .. } => 5, + + // Wide instructions: wide prefix (1 byte) and opcode (1 byte) followed by a 2-byte index, 4 bytes total; IincWide carries an extra 2-byte constant, 6 bytes total + Instruction::AloadWide(_) + | Instruction::AstoreWide(_) + | Instruction::DloadWide(_) + | Instruction::DstoreWide(_) + | Instruction::FloadWide(_) + | Instruction::FstoreWide(_) + | Instruction::IloadWide(_) + | Instruction::IstoreWide(_) + | Instruction::LloadWide(_) + | Instruction::LstoreWide(_) + | Instruction::RetWide(_) => 4, + + Instruction::IincWide { .. } => 6, + + // Variable-length: tableswitch + Instruction::Tableswitch { low, high, .. } => { + let padding = (4 - (address + 1) % 4) % 4; + // 1 (opcode) + padding + 4 (default) + 4 (low) + 4 (high) + 4*(high-low+1) + 1 + padding + 4 + 4 + 4 + 4 * ((*high - *low + 1) as u32) + } + + // Variable-length: lookupswitch + Instruction::Lookupswitch { npairs, .. } => { + let padding = (4 - (address + 1) % 4) % 4; + // 1 (opcode) + padding + 4 (default) + 4 (npairs) + 8*npairs + 1 + padding + 4 + 4 + 8 * npairs + } + } +} + +fn format_manifest(data: &[u8]) -> Vec<String> { + match JarManifest::parse(data) { + Ok(manifest) => { + let mut lines = vec!["=== MANIFEST.MF ===".to_string(), String::new()]; + lines.push("Main Attributes:".to_string()); + for (key, value) in manifest.main_attributes.iter() { + lines.push(format!(" {key}: {value}")); + } + for (name, attrs) in &manifest.entries { + lines.push(String::new()); + lines.push(format!("Section: {name}")); + for (key, value) in attrs.iter() { + lines.push(format!(" {key}: {value}")); + } + } + lines + } + Err(e) => { + let mut lines = vec![format!("Failed to parse manifest: {e}"), String::new()]; + lines.extend(format_text(data)); + lines + } + } +} + +fn format_nested_jar(path: &str, data: &[u8]) -> Vec<String> { + match JarFile::from_bytes(data) { + Ok(nested) => { + let entry_count = nested.entry_names().count(); + let class_count = 
nested.class_names().count(); + let mut lines = vec![ + format!("=== Nested JAR: {path} ==="), + format!("Entries: {entry_count}"), + format!("Classes: {class_count}"), + ]; + + // Show manifest if present + if let Ok(Some(manifest)) = nested.manifest() { + lines.push(String::new()); + lines.push("Manifest:".to_string()); + for (key, value) in manifest.main_attributes.iter() { + lines.push(format!(" {key}: {value}")); + } + } + + lines.push(String::new()); + lines.push("Entry listing:".to_string()); + for name in nested.entry_names() { + lines.push(format!(" {name}")); + } + lines + } + Err(e) => vec![format!("Failed to open nested JAR: {e}")], + } +} + +fn format_text(data: &[u8]) -> Vec<String> { + String::from_utf8_lossy(data) + .lines() + .map(|l| l.to_string()) + .collect() +} + +fn format_hex(data: &[u8]) -> Vec<String> { + let mut lines = Vec::new(); + for (offset, chunk) in data.chunks(16).enumerate() { + let hex: Vec<String> = chunk.iter().map(|b| format!("{b:02x}")).collect(); + let ascii: String = chunk + .iter() + .map(|&b| { + if b.is_ascii_graphic() || b == b' ' { + b as char + } else { + '.' 
+ } + }) + .collect(); + + // Pad hex to fixed width: each missing byte takes three columns (separator and two hex digits) + let hex_str = if hex.len() < 16 { + let mut s = hex.join(" "); + for _ in hex.len()..16 { + s.push_str("   "); + } + s + } else { + hex.join(" ") + }; + + lines.push(format!("{:08x} {hex_str} |{ascii}|", offset * 16)); + } + lines +} + +fn load_entry_content(jar: &JarFile, path: &str) -> (String, Vec<String>) { + if path.ends_with(".class") { + let title = path.to_string(); + let content = format_class(jar, path); + return (title, content); + } + + let data = match jar.get_entry(path) { + Some(d) => d, + None => return (path.to_string(), vec![format!("Entry not found: {path}")]), + }; + + if path == "META-INF/MANIFEST.MF" || path.ends_with("/MANIFEST.MF") { + return (path.to_string(), format_manifest(data)); + } + + if path.ends_with(".jar") { + return (path.to_string(), format_nested_jar(path, data)); + } + + // Try text for common text extensions + let text_exts = [ + ".properties", + ".xml", + ".json", + ".txt", + ".yml", + ".yaml", + ".md", + ".html", + ".css", + ".js", + ".MF", + ".idx", + ".factories", + ".imports", + ".cfg", + ".conf", + ".toml", + ".ini", + ".sql", + ".sh", + ".bat", + ".gradle", + ".kt", + ".java", + ".scala", + ".groovy", + ]; + + if text_exts.iter().any(|ext| path.ends_with(ext)) { + return (path.to_string(), format_text(data)); + } + + // Heuristic: try UTF-8, fall back to hex + if let Ok(text) = std::str::from_utf8(data) { + if text + .chars() + .take(512) + .all(|c| !c.is_control() || c == '\n' || c == '\r' || c == '\t') + { + return ( + path.to_string(), + text.lines().map(|l| l.to_string()).collect(), + ); + } + } + + (format!("{path} (hex)"), format_hex(data)) +} + +fn extract_methods(jar: &JarFile, entry_path: &str) -> Vec<(String, String)> { + let data = match jar.get_entry(entry_path) { + Some(d) => d, + None => return Vec::new(), + }; + let cf = match ClassFile::read(&mut Cursor::new(data)) { + Ok(c) => c, + Err(_) => return Vec::new(), + }; + cf.methods + .iter() + .map(|m| { + let 
name = get_utf8(&cf.const_pool, m.name_index); + let desc = get_utf8(&cf.const_pool, m.descriptor_index); + let access = format_method_access(m.access_flags); + let display = format!("{} {}{}", access, name, descriptor_to_readable(&desc)); + (name, display.trim_start().to_string()) + }) + .collect() +} + +// --------------------------------------------------------------------------- +// App implementation +// --------------------------------------------------------------------------- + +impl<'a> App<'a> { + fn new(jar: JarFile, jar_path: String) -> Self { + let spring_format = detect_format(&jar); + let tree = build_tree(&jar); + let mut viewer = TextArea::default(); + viewer.set_cursor_line_style(Style::default()); + viewer.set_line_number_style(Style::default().fg(Color::DarkGray)); + + let mut app = App { + jar, + jar_path, + spring_format, + tree, + tree_selected: 0, + tree_scroll: 0, + viewer, + vim_mode: VimMode::Normal, + search_buffer: String::new(), + focus: Focus::Tree, + viewer_title: "Viewer".to_string(), + loaded_entry: None, + status_message: String::new(), + should_quit: false, + edit_state: None, + has_unsaved_changes: false, + }; + + // Build initial status + app.update_status(); + app + } + + fn update_status(&mut self) { + // Edit mode status takes priority + match &self.edit_state { + Some(EditState::SelectMethod { .. }) => { + self.status_message = + " SELECT METHOD | j/k:navigate Enter:select Esc:cancel".to_string(); + return; + } + Some(EditState::EditCode { .. 
}) => { + self.status_message = + " EDIT | Ctrl+S:replace Ctrl+P:prepend Esc:cancel".to_string(); + return; + } + None => {} + } + + let mode_str = match &self.vim_mode { + VimMode::Normal => "NORMAL", + VimMode::Search => "SEARCH", + VimMode::Pending(_) => "PENDING", + }; + let focus_str = match self.focus { + Focus::Tree => "TREE", + Focus::Viewer => "VIEW", + }; + let spring = match self.spring_format { + Some(SpringBootFormat::Jar) => " [Spring Boot JAR]", + Some(SpringBootFormat::War) => " [Spring Boot WAR]", + None => "", + }; + let modified = if self.has_unsaved_changes { + " [modified]" + } else { + "" + }; + self.status_message = format!( + " {mode_str} | {focus_str}{spring}{modified} | Tab:switch e:edit W:save hjkl:move q:quit" + ); + } + + fn load_selected_entry(&mut self) { + let vis = visible_indices(&self.tree); + if vis.is_empty() { + return; + } + let idx = vis[self.tree_selected.min(vis.len() - 1)]; + let node = &self.tree[idx]; + + if node.is_dir { + return; + } + + let path = match &node.entry_path { + Some(p) => p.clone(), + None => return, + }; + + // Skip if already loaded + if self.loaded_entry.as_deref() == Some(&path) { + self.focus = Focus::Viewer; + self.update_status(); + return; + } + + let (title, content) = load_entry_content(&self.jar, &path); + self.viewer_title = title; + self.loaded_entry = Some(path); + + // Load content into TextArea + let lines: Vec<String> = if content.is_empty() { + vec!["(empty)".to_string()] + } else { + content + }; + self.viewer = TextArea::new(lines); + self.viewer + .set_cursor_line_style(Style::default().bg(Color::DarkGray)); + self.viewer + .set_line_number_style(Style::default().fg(Color::DarkGray)); + + self.focus = Focus::Viewer; + self.vim_mode = VimMode::Normal; + self.update_status(); + } + + fn toggle_dir(&mut self) { + let vis = visible_indices(&self.tree); + if vis.is_empty() { + return; + } + let idx = vis[self.tree_selected.min(vis.len() - 1)]; + if self.tree[idx].is_dir { + self.tree[idx].expanded 
= !self.tree[idx].expanded; + } + } + + fn expand_dir(&mut self) { + let vis = visible_indices(&self.tree); + if vis.is_empty() { + return; + } + let idx = vis[self.tree_selected.min(vis.len() - 1)]; + if self.tree[idx].is_dir { + self.tree[idx].expanded = true; + } + } + + fn collapse_dir(&mut self) { + let vis = visible_indices(&self.tree); + if vis.is_empty() { + return; + } + let idx = vis[self.tree_selected.min(vis.len() - 1)]; + if self.tree[idx].is_dir { + self.tree[idx].expanded = false; + } else { + // Find parent directory and collapse it + let node_depth = self.tree[idx].depth; + if node_depth > 0 { + // Walk backwards through visible items to find parent + if self.tree_selected > 0 { + for check in (0..self.tree_selected).rev() { + let check_idx = vis[check]; + if self.tree[check_idx].is_dir && self.tree[check_idx].depth < node_depth { + self.tree[check_idx].expanded = false; + self.tree_selected = check; + break; + } + } + } + } + } + } + + fn enter_edit_mode(&mut self) { + let entry_path = match &self.loaded_entry { + Some(p) if p.ends_with(".class") => p.clone(), + _ => return, + }; + let methods = extract_methods(&self.jar, &entry_path); + if methods.is_empty() { + self.status_message = " No methods found in class".to_string(); + return; + } + self.edit_state = Some(EditState::SelectMethod { + entry_path, + methods, + selected: 0, + }); + self.update_status(); + } + + fn apply_edit(&mut self, prepend: bool) { + let (entry_path, method_name, source) = match &self.edit_state { + Some(EditState::EditCode { + entry_path, + method_name, + editor, + .. + }) => { + let src = editor.lines().join("\n"); + (entry_path.clone(), method_name.clone(), src) + } + _ => return, + }; + + let data = match self.jar.get_entry(&entry_path) { + Some(d) => d.to_vec(), + None => { + if let Some(EditState::EditCode { error_message, .. 
}) = &mut self.edit_state { + *error_message = Some(format!("Entry not found: {entry_path}")); + } + return; + } + }; + + let mut cf = match ClassFile::read(&mut Cursor::new(&data)) { + Ok(c) => c, + Err(e) => { + if let Some(EditState::EditCode { error_message, .. }) = &mut self.edit_state { + *error_message = Some(format!("Failed to parse class: {e}")); + } + return; + } + }; + + let opts = CompileOptions::default(); + let result = if prepend { + prepend_method_body(&source, &mut cf, &method_name, None, &opts) + } else { + compile_method_body(&source, &mut cf, &method_name, None, &opts) + }; + + match result { + Ok(()) => match cf.to_bytes() { + Ok(bytes) => { + self.jar.set_entry(&entry_path, bytes); + self.has_unsaved_changes = true; + let action = if prepend { "Prepended to" } else { "Replaced" }; + self.status_message = format!(" {action} '{method_name}' | W:save JAR"); + self.edit_state = None; + self.loaded_entry = None; + self.load_selected_entry(); + } + Err(e) => { + if let Some(EditState::EditCode { error_message, .. }) = &mut self.edit_state { + *error_message = Some(format!("Failed to serialize class: {e}")); + } + } + }, + Err(e) => { + if let Some(EditState::EditCode { error_message, .. 
}) = &mut self.edit_state { + *error_message = Some(format!("{e}")); + } + } + } + } + + fn save_jar(&mut self) { + if !self.has_unsaved_changes { + self.status_message = " No unsaved changes".to_string(); + return; + } + // Swap only the trailing ".jar"; str::replace would also rewrite earlier occurrences in the path + let output_path = match self.jar_path.strip_suffix(".jar") { + Some(stem) => format!("{stem}.patched.jar"), + None => format!("{}.patched", self.jar_path), + }; + match self.jar.save(&output_path) { + Ok(()) => { + self.has_unsaved_changes = false; + self.status_message = format!(" Saved to {output_path}"); + } + Err(e) => { + self.status_message = format!(" Save failed: {e}"); + } + } + } +} + +// --------------------------------------------------------------------------- +// Key handling +// --------------------------------------------------------------------------- + +fn handle_key_event(app: &mut App, key: KeyEvent) { + // Global shortcuts + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') { + app.should_quit = true; + return; + } + + // Edit mode intercepts all input + if app.edit_state.is_some() { + handle_edit_input(app, key); + app.update_status(); + return; + } + + match app.focus { + Focus::Tree => handle_tree_input(app, key), + Focus::Viewer => handle_vim_input(app, key), + } + app.update_status(); +} + +fn handle_tree_input(app: &mut App, key: KeyEvent) { + let vis = visible_indices(&app.tree); + if vis.is_empty() { + return; + } + let max = vis.len().saturating_sub(1); + + match &app.vim_mode { + VimMode::Pending('g') => { + app.vim_mode = VimMode::Normal; + if key.code == KeyCode::Char('g') { + app.tree_selected = 0; + } + return; + } + VimMode::Pending(_) => { + app.vim_mode = VimMode::Normal; + return; + } + _ => {} + } + + match key.code { + KeyCode::Char('q') => app.should_quit = true, + KeyCode::Char('j') | KeyCode::Down => { + app.tree_selected = (app.tree_selected + 1).min(max); + } + KeyCode::Char('k') | KeyCode::Up => { + app.tree_selected = app.tree_selected.saturating_sub(1); + } + 
KeyCode::Char('l') | KeyCode::Right | KeyCode::Char(' ') => { + let idx = vis[app.tree_selected.min(max)]; + if app.tree[idx].is_dir { + app.expand_dir(); + } else { + app.load_selected_entry(); + } + } + KeyCode::Char('h') | KeyCode::Left => app.collapse_dir(), + KeyCode::Enter => { + let idx = vis[app.tree_selected.min(max)]; + if app.tree[idx].is_dir { + app.toggle_dir(); + } else { + app.load_selected_entry(); + } + } + KeyCode::Char('g') => { + app.vim_mode = VimMode::Pending('g'); + } + KeyCode::Char('G') => { + app.tree_selected = max; + } + KeyCode::Char('e') => { + app.enter_edit_mode(); + } + KeyCode::Char('W') => { + app.save_jar(); + } + KeyCode::Tab => { + if app.loaded_entry.is_some() { + app.focus = Focus::Viewer; + } + } + _ => {} + } +} + +fn handle_edit_input(app: &mut App, key: KeyEvent) { + // Determine which edit sub-state we're in + let is_select = matches!(app.edit_state, Some(EditState::SelectMethod { .. })); + + if is_select { + handle_edit_select(app, key); + } else { + handle_edit_code(app, key); + } +} + +fn handle_edit_select(app: &mut App, key: KeyEvent) { + let method_count = match &app.edit_state { + Some(EditState::SelectMethod { methods, .. }) => methods.len(), + _ => return, + }; + let max = method_count.saturating_sub(1); + + match key.code { + KeyCode::Char('j') | KeyCode::Down => { + if let Some(EditState::SelectMethod { selected, .. }) = &mut app.edit_state { + *selected = (*selected + 1).min(max); + } + } + KeyCode::Char('k') | KeyCode::Up => { + if let Some(EditState::SelectMethod { selected, .. 
}) = &mut app.edit_state { + *selected = selected.saturating_sub(1); + } + } + KeyCode::Enter => { + // Transition to EditCode + if let Some(EditState::SelectMethod { + entry_path, + methods, + selected, + }) = app.edit_state.take() + { + let (method_name, _) = &methods[selected]; + let mut editor = + TextArea::new(vec!["{ ".to_string(), " ".to_string(), "}".to_string()]); + editor.set_cursor_line_style(Style::default().bg(Color::DarkGray)); + editor.set_cursor_style(Style::default().bg(Color::White).fg(Color::Black)); + editor.set_line_number_style(Style::default().fg(Color::DarkGray)); + // Position cursor on the middle line + editor.move_cursor(CursorMove::Down); + editor.move_cursor(CursorMove::End); + + app.edit_state = Some(EditState::EditCode { + entry_path, + method_name: method_name.clone(), + editor, + error_message: None, + }); + } + } + KeyCode::Esc => { + app.edit_state = None; + } + _ => {} + } +} + +fn handle_edit_code(app: &mut App, key: KeyEvent) { + // Ctrl+S → replace, Ctrl+P → prepend, Escape → cancel + if key.modifiers.contains(KeyModifiers::CONTROL) { + match key.code { + KeyCode::Char('s') => { + app.apply_edit(false); + return; + } + KeyCode::Char('p') => { + app.apply_edit(true); + return; + } + _ => {} + } + } + + if key.code == KeyCode::Esc { + app.edit_state = None; + return; + } + + // Forward to TextArea editor + if let Some(EditState::EditCode { + editor, + error_message, + .. 
+ }) = &mut app.edit_state + { + editor.input(key); + // Clear error on new input + if error_message.is_some() { + *error_message = None; + } + } +} + +fn handle_vim_input(app: &mut App, key: KeyEvent) { + match &app.vim_mode { + VimMode::Normal => handle_vim_normal(app, key), + VimMode::Search => handle_vim_search(app, key), + VimMode::Pending(c) => { + let c = *c; + handle_vim_pending(app, key, c); + } + } +} + +fn handle_vim_normal(app: &mut App, key: KeyEvent) { + match key.code { + KeyCode::Char('q') | KeyCode::Tab => { + app.focus = Focus::Tree; + } + KeyCode::Char('j') | KeyCode::Down => { + app.viewer.move_cursor(CursorMove::Down); + } + KeyCode::Char('k') | KeyCode::Up => { + app.viewer.move_cursor(CursorMove::Up); + } + KeyCode::Char('h') | KeyCode::Left => { + app.viewer.move_cursor(CursorMove::Back); + } + KeyCode::Char('l') | KeyCode::Right => { + app.viewer.move_cursor(CursorMove::Forward); + } + KeyCode::Char('w') => { + app.viewer.move_cursor(CursorMove::WordForward); + } + KeyCode::Char('b') => { + app.viewer.move_cursor(CursorMove::WordBack); + } + KeyCode::Char('0') => { + app.viewer.move_cursor(CursorMove::Head); + } + KeyCode::Char('$') => { + app.viewer.move_cursor(CursorMove::End); + } + KeyCode::Char('g') => { + app.vim_mode = VimMode::Pending('g'); + } + KeyCode::Char('G') => { + app.viewer.move_cursor(CursorMove::Bottom); + } + KeyCode::Char('d') if key.modifiers.contains(KeyModifiers::CONTROL) => { + app.viewer.scroll((10, 0)); + } + KeyCode::Char('u') if key.modifiers.contains(KeyModifiers::CONTROL) => { + app.viewer.scroll((-10, 0)); + } + KeyCode::Char('/') => { + app.vim_mode = VimMode::Search; + app.search_buffer.clear(); + } + KeyCode::Char('n') => { + app.viewer.search_forward(false); + } + KeyCode::Char('N') => { + app.viewer.search_back(false); + } + _ => {} + } +} + +fn handle_vim_search(app: &mut App, key: KeyEvent) { + match key.code { + KeyCode::Esc => { + app.vim_mode = VimMode::Normal; + app.search_buffer.clear(); + } + 
KeyCode::Enter => { + if !app.search_buffer.is_empty() { + app.viewer.set_search_pattern(&app.search_buffer).ok(); + app.viewer.search_forward(false); + } + app.vim_mode = VimMode::Normal; + } + KeyCode::Backspace => { + app.search_buffer.pop(); + } + KeyCode::Char(c) => { + app.search_buffer.push(c); + } + _ => {} + } +} + +fn handle_vim_pending(app: &mut App, key: KeyEvent, pending: char) { + app.vim_mode = VimMode::Normal; + if pending == 'g' && key.code == KeyCode::Char('g') { + app.viewer.move_cursor(CursorMove::Top); + } + // Any other combo just cancels the pending state +} + +// --------------------------------------------------------------------------- +// Rendering +// --------------------------------------------------------------------------- + +fn render(app: &mut App, frame: &mut ratatui::Frame) { + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Min(1), Constraint::Length(1)]) + .split(frame.area()); + + let main_area = chunks[0]; + let status_area = chunks[1]; + + let main_chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(30), Constraint::Percentage(70)]) + .split(main_area); + + render_tree(app, frame, main_chunks[0]); + render_viewer(app, frame, main_chunks[1]); + render_status(app, frame, status_area); +} + +fn render_tree(app: &mut App, frame: &mut ratatui::Frame, area: Rect) { + let vis = visible_indices(&app.tree); + let selected = app.tree_selected.min(vis.len().saturating_sub(1)); + + // Adjust scroll to keep selected visible + let inner_height = area.height.saturating_sub(2) as usize; // account for borders + if selected < app.tree_scroll { + app.tree_scroll = selected; + } else if selected >= app.tree_scroll + inner_height { + app.tree_scroll = selected - inner_height + 1; + } + + let items: Vec<ListItem> = vis + .iter() + .enumerate() + .skip(app.tree_scroll) + .take(inner_height) + .map(|(i, &idx)| { + let node = &app.tree[idx]; + let indent = " 
".repeat(node.depth); + let icon = if node.is_dir { + if node.expanded { "[-] " } else { "[+] " } + } else if node.label.ends_with(".class") { + "> " + } else { + " " + }; + + let color = if node.is_dir { + Color::Yellow + } else if node.label.ends_with(".class") { + Color::Cyan + } else if node.label.ends_with(".jar") { + Color::Green + } else { + Color::White + }; + + let style = if i == selected { + if app.focus == Focus::Tree { + Style::default() + .fg(Color::Black) + .bg(Color::White) + .add_modifier(Modifier::BOLD) + } else { + Style::default().fg(color).bg(Color::DarkGray) + } + } else { + Style::default().fg(color) + }; + + ListItem::new(Line::from(Span::styled( + format!("{indent}{icon}{}", node.label), + style, + ))) + }) + .collect(); + + let title = match app.spring_format { + Some(SpringBootFormat::Jar) => " JAR Explorer [Spring Boot] ", + Some(SpringBootFormat::War) => " JAR Explorer [Spring Boot WAR] ", + None => " JAR Explorer ", + }; + + let tree_block = Block::default() + .borders(Borders::ALL) + .title(title) + .border_style(if app.focus == Focus::Tree { + Style::default().fg(Color::Cyan) + } else { + Style::default().fg(Color::DarkGray) + }); + + let list = List::new(items).block(tree_block); + frame.render_widget(list, area); +} + +fn render_viewer(app: &mut App, frame: &mut ratatui::Frame, area: Rect) { + // Edit mode rendering + match &mut app.edit_state { + Some(EditState::SelectMethod { + methods, selected, .. + }) => { + render_method_selector(frame, area, methods, *selected); + return; + } + Some(EditState::EditCode { + method_name, + editor, + error_message, + .. 
+ }) => { + render_code_editor(frame, area, method_name, editor, error_message.as_deref()); + return; + } + None => {} + } + + let title = if app.vim_mode == VimMode::Search { + format!(" {} | /{} ", app.viewer_title, app.search_buffer) + } else { + format!(" {} ", app.viewer_title) + }; + + let block = Block::default() + .borders(Borders::ALL) + .title(title) + .border_style(if app.focus == Focus::Viewer { + Style::default().fg(Color::Cyan) + } else { + Style::default().fg(Color::DarkGray) + }); + + app.viewer.set_block(block); + + if app.focus == Focus::Viewer { + app.viewer + .set_cursor_line_style(Style::default().bg(Color::DarkGray)); + app.viewer + .set_cursor_style(Style::default().bg(Color::White).fg(Color::Black)); + } else { + app.viewer.set_cursor_line_style(Style::default()); + app.viewer.set_cursor_style(Style::default()); + } + + frame.render_widget(&app.viewer, area); +} + +fn render_method_selector( + frame: &mut ratatui::Frame, + area: Rect, + methods: &[(String, String)], + selected: usize, +) { + let inner_height = area.height.saturating_sub(2) as usize; + let scroll = if selected >= inner_height { + selected - inner_height + 1 + } else { + 0 + }; + + let items: Vec<ListItem> = methods + .iter() + .enumerate() + .skip(scroll) + .take(inner_height) + .map(|(i, (_, display))| { + let style = if i == selected { + Style::default() + .fg(Color::Black) + .bg(Color::White) + .add_modifier(Modifier::BOLD) + } else { + Style::default().fg(Color::White) + }; + ListItem::new(Line::from(Span::styled(format!(" {display}"), style))) + }) + .collect(); + + let block = Block::default() + .borders(Borders::ALL) + .title(" Select Method ") + .border_style(Style::default().fg(Color::Yellow)); + + let list = List::new(items).block(block); + frame.render_widget(list, area); +} + +fn render_code_editor( + frame: &mut ratatui::Frame, + area: Rect, + method_name: &str, + editor: &mut TextArea, + error_message: Option<&str>, +) { + if let Some(err) = error_message { + // Split: 
editor on top, error on bottom + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Min(3), Constraint::Length(3)]) + .split(area); + + let block = Block::default() + .borders(Borders::ALL) + .title(format!(" Edit: {method_name} ")) + .border_style(Style::default().fg(Color::Green)); + editor.set_block(block); + editor.set_cursor_line_style(Style::default().bg(Color::DarkGray)); + editor.set_cursor_style(Style::default().bg(Color::White).fg(Color::Black)); + editor.set_line_number_style(Style::default().fg(Color::DarkGray)); + frame.render_widget(&*editor, chunks[0]); + + let error_block = Block::default() + .borders(Borders::ALL) + .title(" Error ") + .border_style(Style::default().fg(Color::Red)); + let error_text = Paragraph::new(Line::from(Span::styled( + err, + Style::default().fg(Color::Red), + ))) + .block(error_block); + frame.render_widget(error_text, chunks[1]); + } else { + let block = Block::default() + .borders(Borders::ALL) + .title(format!(" Edit: {method_name} ")) + .border_style(Style::default().fg(Color::Green)); + editor.set_block(block); + editor.set_cursor_line_style(Style::default().bg(Color::DarkGray)); + editor.set_cursor_style(Style::default().bg(Color::White).fg(Color::Black)); + editor.set_line_number_style(Style::default().fg(Color::DarkGray)); + frame.render_widget(&*editor, area); + } +} + +fn render_status(app: &App, frame: &mut ratatui::Frame, area: Rect) { + let status = Paragraph::new(Line::from(Span::styled( + &app.status_message, + Style::default().fg(Color::Black).bg(Color::White), + ))); + frame.render_widget(status, area); +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let args: Vec<String> = std::env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: jar_explorer "); + std::process::exit(1); + } + let path = &args[1]; + + 
let jar = JarFile::open(path).map_err(|e| format!("Failed to open {path}: {e}"))?; + + // Terminal setup + terminal::enable_raw_mode()?; + let mut stdout = io::stdout(); + crossterm::execute!(stdout, EnterAlternateScreen)?; + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::new(backend)?; + + let mut app = App::new(jar, path.clone()); + + // Main loop + loop { + terminal.draw(|frame| render(&mut app, frame))?; + + if event::poll(std::time::Duration::from_millis(50))? { + if let Event::Key(key) = event::read()? { + handle_key_event(&mut app, key); + } + } + + if app.should_quit { + break; + } + } + + // Cleanup + terminal::disable_raw_mode()?; + crossterm::execute!(terminal.backend_mut(), LeaveAlternateScreen)?; + terminal.show_cursor()?; + + Ok(()) +} diff --git a/examples/jar_patch.rs b/examples/jar_patch.rs new file mode 100644 index 0000000..58ae458 --- /dev/null +++ b/examples/jar_patch.rs @@ -0,0 +1,110 @@ +//! Example: patch methods inside a JAR file using the jar-patch feature. +//! +//! This example: +//! 1. Compiles a Java class with `javac` +//! 2. Packs the `.class` file into a JAR using `JarFile` +//! 3. Patches method bodies using `patch_jar!` +//! 4. Saves the modified JAR and runs it with `java` +//! +//! Run with: +//! 
cargo run --example jar_patch --features jar-patch + +use std::fs; +use std::process::Command; + +use classfile_parser::jar_utils::JarFile; +use classfile_parser::{patch_jar, patch_jar_method}; + +fn main() { + // ── Step 1: Create and compile a Java class ────────────────────────── + let tmp_dir = std::env::temp_dir().join("classfile_jar_patch_example"); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let java_src = tmp_dir.join("HelloWorld.java"); + fs::write( + &java_src, + r#" +public class HelloWorld { + public static void greet() { + System.out.println("original greet"); + } + + public static void main(String[] args) { + System.out.println("original main"); + greet(); + } +} +"#, + ) + .unwrap(); + + let javac = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg(&java_src) + .output() + .expect("javac not found — make sure a JDK is on your PATH"); + assert!( + javac.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&javac.stderr) + ); + println!("Compiled HelloWorld.java"); + + // ── Step 2: Pack into a JAR ────────────────────────────────────────── + let class_bytes = fs::read(tmp_dir.join("HelloWorld.class")).unwrap(); + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + println!("Packed into JAR ({} entries)", jar.entry_names().count()); + + // ── Step 3a: Patch a single method ─────────────────────────────────── + patch_jar_method!( + jar, + "HelloWorld.class", + "greet", + r#"{ + System.out.println("patched greet!"); + }"# + ) + .unwrap(); + println!("Patched greet()"); + + // ── Step 3b: Patch multiple methods across classes ─────────────────── + // + // patch_jar! batches by class — each class is parsed once, all its + // methods are patched, and the class is written back once. 
+ patch_jar!(jar, { + "HelloWorld.class" => { + "main" => r#"{ + System.out.println("patched main!"); + greet(); + }"#, + }, + }) + .unwrap(); + println!("Patched main()"); + + // ── Step 4: Save and run ───────────────────────────────────────────── + let jar_path = tmp_dir.join("patched.jar"); + jar.save(&jar_path).unwrap(); + println!("Saved {}\n", jar_path.display()); + + let run = Command::new("java") + .arg("-cp") + .arg(&jar_path) + .arg("HelloWorld") + .output() + .expect("java not found"); + if run.status.success() { + println!("{}", String::from_utf8_lossy(&run.stdout).trim()); + } else { + eprintln!( + "java failed: {}", + String::from_utf8_lossy(&run.stderr).trim() + ); + std::process::exit(1); + } + + let _ = fs::remove_dir_all(&tmp_dir); +} diff --git a/java-assets/compiled-classes/NestExample$Inner.class b/java-assets/compiled-classes/NestExample$Inner.class new file mode 100644 index 0000000..aaa0ce8 Binary files /dev/null and b/java-assets/compiled-classes/NestExample$Inner.class differ diff --git a/java-assets/compiled-classes/NestExample.class b/java-assets/compiled-classes/NestExample.class new file mode 100644 index 0000000..857f51d Binary files /dev/null and b/java-assets/compiled-classes/NestExample.class differ diff --git a/java-assets/compiled-classes/RecordExample.class b/java-assets/compiled-classes/RecordExample.class new file mode 100644 index 0000000..b65d946 Binary files /dev/null and b/java-assets/compiled-classes/RecordExample.class differ diff --git a/java-assets/compiled-classes/SealedChild1.class b/java-assets/compiled-classes/SealedChild1.class new file mode 100644 index 0000000..8886a70 Binary files /dev/null and b/java-assets/compiled-classes/SealedChild1.class differ diff --git a/java-assets/compiled-classes/SealedChild2.class b/java-assets/compiled-classes/SealedChild2.class new file mode 100644 index 0000000..acc86e2 Binary files /dev/null and b/java-assets/compiled-classes/SealedChild2.class differ diff --git 
a/java-assets/compiled-classes/SealedExample.class b/java-assets/compiled-classes/SealedExample.class new file mode 100644 index 0000000..8d6c512 Binary files /dev/null and b/java-assets/compiled-classes/SealedExample.class differ diff --git a/java-assets/compiled-classes/com/some/Thing.class b/java-assets/compiled-classes/com/some/Thing.class new file mode 100644 index 0000000..5254631 Binary files /dev/null and b/java-assets/compiled-classes/com/some/Thing.class differ diff --git a/java-assets/compiled-classes/module-info.class b/java-assets/compiled-classes/module-info.class new file mode 100644 index 0000000..9cc9a8a Binary files /dev/null and b/java-assets/compiled-classes/module-info.class differ diff --git a/java-assets/src/NestExample.java b/java-assets/src/NestExample.java new file mode 100644 index 0000000..5e057b3 --- /dev/null +++ b/java-assets/src/NestExample.java @@ -0,0 +1,9 @@ +public class NestExample { + private int value = 42; + + class Inner { + int getValue() { + return value; + } + } +} diff --git a/java-assets/src/ParamAccess.java b/java-assets/src/ParamAccess.java new file mode 100644 index 0000000..d894aa9 --- /dev/null +++ b/java-assets/src/ParamAccess.java @@ -0,0 +1,13 @@ +public class ParamAccess { + public static void main(String[] args) { + // Call other methods so their descriptors are in the constant pool + wideParams(0, 0L, ""); + new ParamAccess().instanceMethod(""); + } + public static void wideParams(int a, long b, String c) { + System.out.println("original"); + } + public void instanceMethod(String name) { + System.out.println("original"); + } +} diff --git a/java-assets/src/PrependTest.java b/java-assets/src/PrependTest.java new file mode 100644 index 0000000..49bdf23 --- /dev/null +++ b/java-assets/src/PrependTest.java @@ -0,0 +1,26 @@ +public class PrependTest { + public static void main(String[] args) { + // Call all methods so their descriptors are in the constant pool + withParams(""); + withLocal(); + withTryCatch(); + 
withBranch(0); + } + public static void withParams(String name) { + System.out.println("hello " + name); + } + public static void withLocal() { + int x = 10; + System.out.println("x=" + x); + } + public static void withTryCatch() { + try { + System.out.println("try"); + } catch (Exception e) { + System.out.println("catch"); + } + } + public static void withBranch(int n) { + System.out.println("n=" + n); + } +} diff --git a/java-assets/src/RecordExample.java b/java-assets/src/RecordExample.java new file mode 100644 index 0000000..def914d --- /dev/null +++ b/java-assets/src/RecordExample.java @@ -0,0 +1 @@ +public record RecordExample(int x, String name) {} diff --git a/java-assets/src/SealedExample.java b/java-assets/src/SealedExample.java new file mode 100644 index 0000000..eccbb77 --- /dev/null +++ b/java-assets/src/SealedExample.java @@ -0,0 +1,4 @@ +public sealed class SealedExample permits SealedChild1, SealedChild2 {} + +final class SealedChild1 extends SealedExample {} +final class SealedChild2 extends SealedExample {} diff --git a/java-assets/src/com/some/Thing.java b/java-assets/src/com/some/Thing.java new file mode 100644 index 0000000..26dee72 --- /dev/null +++ b/java-assets/src/com/some/Thing.java @@ -0,0 +1,3 @@ +package com.some; + +public class Thing {} diff --git a/java-assets/src/module-info.java b/java-assets/src/module-info.java new file mode 100644 index 0000000..bb0aef9 --- /dev/null +++ b/java-assets/src/module-info.java @@ -0,0 +1,3 @@ +module my.module { + exports com.some; +} diff --git a/src/attribute_info/mod.rs b/src/attribute_info/mod.rs index 49de015..6108ca4 100644 --- a/src/attribute_info/mod.rs +++ b/src/attribute_info/mod.rs @@ -1,25 +1,3 @@ -mod parser; mod types; pub use self::types::*; - -pub use self::parser::attribute_parser; -pub use self::parser::bootstrap_methods_attribute_parser; -pub use self::parser::code_attribute_parser; -pub use self::parser::constant_value_attribute_parser; -pub use self::parser::element_value_parser; 
-pub use self::parser::enclosing_method_attribute_parser; -pub use self::parser::exceptions_attribute_parser; -pub use self::parser::inner_classes_attribute_parser; -pub use self::parser::line_number_table_attribute_parser; -pub use self::parser::method_parameters_attribute_parser; -pub use self::parser::runtime_invisible_annotations_attribute_parser; -pub use self::parser::runtime_invisible_parameter_annotations_attribute_parser; -pub use self::parser::runtime_invisible_type_annotations_attribute_parser; -pub use self::parser::runtime_visible_annotations_attribute_parser; -pub use self::parser::runtime_visible_parameter_annotations_attribute_parser; -pub use self::parser::runtime_visible_type_annotations_attribute_parser; -pub use self::parser::signature_attribute_parser; -pub use self::parser::source_debug_extension_parser; -pub use self::parser::sourcefile_attribute_parser; -pub use self::parser::stack_map_table_attribute_parser; diff --git a/src/attribute_info/parser.rs b/src/attribute_info/parser.rs deleted file mode 100644 index 7226f3c..0000000 --- a/src/attribute_info/parser.rs +++ /dev/null @@ -1,699 +0,0 @@ -use nom::{ - Err as BaseErr, - bytes::complete::take, - combinator::{map, success}, - error::{Error, ErrorKind}, - multi::count, - number::complete::{be_u8, be_u16, be_u32}, -}; - -use crate::attribute_info::types::StackMapFrame::*; -use crate::attribute_info::*; - -// Using a type alias here evades a Clippy warning about complex types. 
-type Err<E> = BaseErr<Error<E>>; - -pub fn attribute_parser(input: &[u8]) -> Result<(&[u8], AttributeInfo), Err<&[u8]>> { - let (input, attribute_name_index) = be_u16(input)?; - let (input, attribute_length) = be_u32(input)?; - let (input, info) = take(attribute_length)(input)?; - Ok(( - input, - AttributeInfo { - attribute_name_index, - attribute_length, - info: info.to_owned(), - }, - )) -} - -pub fn exception_entry_parser(input: &[u8]) -> Result<(&[u8], ExceptionEntry), Err<&[u8]>> { - let (input, start_pc) = be_u16(input)?; - let (input, end_pc) = be_u16(input)?; - let (input, handler_pc) = be_u16(input)?; - let (input, catch_type) = be_u16(input)?; - Ok(( - input, - ExceptionEntry { - start_pc, - end_pc, - handler_pc, - catch_type, - }, - )) -} - -pub fn code_attribute_parser(input: &[u8]) -> Result<(&[u8], CodeAttribute), Err<&[u8]>> { - let (input, max_stack) = be_u16(input)?; - let (input, max_locals) = be_u16(input)?; - let (input, code_length) = be_u32(input)?; - let (input, code) = take(code_length)(input)?; - let (input, exception_table_length) = be_u16(input)?; - let (input, exception_table) = - count(exception_entry_parser, exception_table_length as usize)(input)?; - let (input, attributes_count) = be_u16(input)?; - let (input, attributes) = count(attribute_parser, attributes_count as usize)(input)?; - Ok(( - input, - CodeAttribute { - max_stack, - max_locals, - code_length, - code: code.to_owned(), - exception_table_length, - exception_table, - attributes_count, - attributes, - }, - )) -} - -pub fn method_parameters_attribute_parser( - input: &[u8], -) -> Result<(&[u8], MethodParametersAttribute), Err<&[u8]>> { - let (input, parameters_count) = be_u8(input)?; - let (input, parameters) = count(parameters_parser, parameters_count as usize)(input)?; - Ok(( - input, - MethodParametersAttribute { - parameters_count, - parameters, - }, - )) -} - -pub fn parameters_parser(input: &[u8]) -> Result<(&[u8], ParameterAttribute), Err<&[u8]>> { - let (input, name_index) = 
be_u16(input)?; - let (input, access_flags) = be_u16(input)?; - Ok(( - input, - ParameterAttribute { - name_index, - access_flags, - }, - )) -} - -pub fn inner_classes_attribute_parser( - input: &[u8], -) -> Result<(&[u8], InnerClassesAttribute), Err<&[u8]>> { - let (input, number_of_classes) = be_u16(input)?; - let (input, classes) = count(inner_class_info_parser, number_of_classes as usize)(input)?; - let ret = ( - input, - InnerClassesAttribute { - number_of_classes, - classes, - }, - ); - - Ok(ret) -} - -pub fn inner_class_info_parser(input: &[u8]) -> Result<(&[u8], InnerClassInfo), Err<&[u8]>> { - let (input, inner_class_info_index) = be_u16(input)?; - let (input, outer_class_info_index) = be_u16(input)?; - let (input, inner_name_index) = be_u16(input)?; - let (input, inner_class_access_flags) = be_u16(input)?; - Ok(( - input, - InnerClassInfo { - inner_class_info_index, - outer_class_info_index, - inner_name_index, - inner_class_access_flags, - }, - )) -} - -pub fn enclosing_method_attribute_parser( - input: &[u8], -) -> Result<(&[u8], EnclosingMethodAttribute), Err<&[u8]>> { - let (input, class_index) = be_u16(input)?; - let (input, method_index) = be_u16(input)?; - Ok(( - input, - EnclosingMethodAttribute { - class_index, - method_index, - }, - )) -} - -pub fn signature_attribute_parser(input: &[u8]) -> Result<(&[u8], SignatureAttribute), Err<&[u8]>> { - let (input, signature_index) = be_u16(input)?; - Ok((input, SignatureAttribute { signature_index })) -} - -pub fn runtime_visible_annotations_attribute_parser( - input: &[u8], -) -> Result<(&[u8], RuntimeVisibleAnnotationsAttribute), Err<&[u8]>> { - let (input, num_annotations) = be_u16(input)?; - let (input, annotations) = count(annotation_parser, num_annotations as usize)(input)?; - Ok(( - input, - RuntimeVisibleAnnotationsAttribute { - num_annotations, - annotations, - }, - )) -} - -pub fn runtime_invisible_annotations_attribute_parser( - input: &[u8], -) -> Result<(&[u8], 
RuntimeInvisibleAnnotationsAttribute), Err<&[u8]>> { - let (input, num_annotations) = be_u16(input)?; - let (input, annotations) = count(annotation_parser, num_annotations as usize)(input)?; - Ok(( - input, - RuntimeInvisibleAnnotationsAttribute { - num_annotations, - annotations, - }, - )) -} - -pub fn runtime_visible_parameter_annotations_attribute_parser( - input: &[u8], -) -> Result<(&[u8], RuntimeVisibleParameterAnnotationsAttribute), Err<&[u8]>> { - let (input, num_parameters) = be_u8(input)?; - let (input, parameter_annotations) = count( - runtime_visible_annotations_attribute_parser, - num_parameters as usize, - )(input)?; - Ok(( - input, - RuntimeVisibleParameterAnnotationsAttribute { - num_parameters, - parameter_annotations, - }, - )) -} - -pub fn runtime_invisible_parameter_annotations_attribute_parser( - input: &[u8], -) -> Result<(&[u8], RuntimeInvisibleParameterAnnotationsAttribute), Err<&[u8]>> { - let (input, num_parameters) = be_u8(input)?; - let (input, parameter_annotations) = count( - runtime_invisible_annotations_attribute_parser, - num_parameters as usize, - )(input)?; - Ok(( - input, - RuntimeInvisibleParameterAnnotationsAttribute { - num_parameters, - parameter_annotations, - }, - )) -} -pub fn runtime_visible_type_annotations_attribute_parser( - input: &[u8], -) -> Result<(&[u8], RuntimeVisibleTypeAnnotationsAttribute), Err<&[u8]>> { - let (input, num_annotations) = be_u16(input)?; - let (input, type_annotations) = count(type_annotation_parser, num_annotations as usize)(input)?; - - Ok(( - input, - RuntimeVisibleTypeAnnotationsAttribute { - num_annotations, - type_annotations, - }, - )) -} - -pub fn runtime_invisible_type_annotations_attribute_parser( - input: &[u8], -) -> Result<(&[u8], RuntimeInvisibleTypeAnnotationsAttribute), Err<&[u8]>> { - let (input, num_annotations) = be_u16(input)?; - let (input, type_annotations) = count(type_annotation_parser, num_annotations as usize)(input)?; - - Ok(( - input, - 
RuntimeInvisibleTypeAnnotationsAttribute { - num_annotations, - type_annotations, - }, - )) -} - -pub fn type_annotation_parser(input: &[u8]) -> Result<(&[u8], TypeAnnotation), Err<&[u8]>> { - let (input, target_type) = be_u8(input)?; - let mut target_info: TargetInfo = TargetInfo::Empty; - match target_type { - 0x0 | 0x1 => { - let (_input, type_parameter_index) = be_u8(input)?; - target_info = TargetInfo::TypeParameter { - type_parameter_index, - }; - } - 0x10 => { - let (_input, supertype_index) = be_u16(input)?; - target_info = TargetInfo::SuperType { supertype_index }; - } - 0x11..=0x12 => { - let (input, type_parameter_index) = be_u8(input)?; - let (_input, bound_index) = be_u8(input)?; - target_info = TargetInfo::TypeParameterBound { - type_parameter_index, - bound_index, - } - } - 0x13..=0x15 => { - // Empty target_info - } - 0x16 => { - let (_input, formal_parameter_index) = be_u8(input)?; - target_info = TargetInfo::FormalParameter { - formal_parameter_index, - }; - } - 0x17 => { - let (_input, throws_type_index) = be_u16(input)?; - target_info = TargetInfo::Throws { throws_type_index }; - } - 0x40 | 0x41 => { - let (input, table_length) = be_u16(input)?; - let (_input, tables) = count( - local_variable_table_annotation_parser, - table_length as usize, - )(input)?; - target_info = TargetInfo::LocalVar { - table_length, - tables, - }; - } - 0x42 => { - let (_input, exception_table_index) = be_u16(input)?; - target_info = TargetInfo::Catch { - exception_table_index, - } - } - 0x43..=0x46 => { - let (_input, offset) = be_u16(input)?; - target_info = TargetInfo::Offset { offset } - } - 0x47..=0x4B => { - let (input, offset) = be_u16(input)?; - let (_input, type_argument_index) = be_u8(input)?; - target_info = TargetInfo::TypeArgument { - offset, - type_argument_index, - }; - } - _ => { - eprintln!( - "Parsing RuntimeVisibleTypeAnnotationsAttribute with target_type = {}", - target_type - ); - } - } - let (input, target_path) = target_path_parser(input)?; - let 
(input, type_index) = be_u16(input)?; - let (input, num_element_value_pairs) = be_u16(input)?; - let (input, element_value_pairs) = - count(element_value_pair_parser, num_element_value_pairs as usize)(input)?; - - Ok(( - input, - TypeAnnotation { - target_type, - target_info, - target_path, - type_index, - num_element_value_pairs, - element_value_pairs, - }, - )) -} - -fn target_path_parser(input: &[u8]) -> Result<(&[u8], TypePath), Err<&[u8]>> { - let (input, path_length) = be_u8(input)?; - let (input, paths) = count( - |input| { - let (input, type_path_kind) = be_u8(input)?; - let (input, type_argument_index) = be_u8(input)?; - Ok(( - input, - TypePathEntry { - type_path_kind, - type_argument_index, - }, - )) - }, - path_length as usize, - )(input)?; - Ok((input, TypePath { path_length, paths })) -} - -pub fn local_variable_table_annotation_parser( - input: &[u8], -) -> Result<(&[u8], LocalVarTableAnnotation), Err<&[u8]>> { - let (input, start_pc) = be_u16(input)?; - let (input, length) = be_u16(input)?; - let (input, index) = be_u16(input)?; - Ok(( - input, - LocalVarTableAnnotation { - start_pc, - length, - index, - }, - )) -} -fn annotation_parser(input: &[u8]) -> Result<(&[u8], RuntimeAnnotation), Err<&[u8]>> { - let (input, type_index) = be_u16(input)?; - let (input, num_element_value_pairs) = be_u16(input)?; - eprintln!( - "Parsing annotation with type index = {}, and {} element value pairs", - type_index, num_element_value_pairs - ); - let (input, element_value_pairs) = - count(element_value_pair_parser, num_element_value_pairs as usize)(input)?; - Ok(( - input, - RuntimeAnnotation { - type_index, - num_element_value_pairs, - element_value_pairs, - }, - )) -} - -fn element_value_pair_parser(input: &[u8]) -> Result<(&[u8], ElementValuePair), Err<&[u8]>> { - let (input, element_name_index) = be_u16(input)?; - let (input, value) = element_value_parser(input)?; - Ok(( - input, - ElementValuePair { - element_name_index, - value, - }, - )) -} - -fn 
array_value_parser(input: &[u8]) -> Result<(&[u8], ElementArrayValue), Err<&[u8]>> { - let (input, num_values) = be_u16(input)?; - let (input, values) = count(element_value_parser, num_values as usize)(input)?; - Ok((input, ElementArrayValue { num_values, values })) -} - -pub fn element_value_parser(input: &[u8]) -> Result<(&[u8], ElementValue), Err<&[u8]>> { - let (input, tag) = be_u8(input)?; - eprintln!("Element value parsing: tag = {}", tag as char); - - match tag as char { - 'B' | 'C' | 'I' | 'S' | 'Z' | 'D' | 'F' | 'J' | 's' => { - let (input, const_value_index) = be_u16(input)?; - eprintln!( - "Element value parsing: const_value_index = {}", - const_value_index - ); - Ok(( - input, - ElementValue::ConstValueIndex { - tag: tag as char, - value: const_value_index, - }, - )) - } - 'e' => { - let (input, enum_const_value) = enum_const_value_parser(input)?; - eprintln!( - "Element value parsing: enum_const_value = {:?}", - enum_const_value - ); - Ok((input, ElementValue::EnumConst(enum_const_value))) - } - 'c' => { - let (input, class_info_index) = be_u16(input)?; - eprintln!( - "Element value parsing: class_info_index = {}", - class_info_index - ); - Ok((input, ElementValue::ClassInfoIndex(class_info_index))) - } - '@' => { - let (input, annotation_value) = annotation_parser(input)?; - eprintln!( - "Element value parsing: annotation_value = {:?}", - annotation_value - ); - Ok((input, ElementValue::AnnotationValue(annotation_value))) - } - '[' => { - let (input, array_value) = array_value_parser(input)?; - eprintln!("Element value parsing: array_value = {:?}", array_value); - Ok((input, ElementValue::ElementArray(array_value))) - } - _ => Result::Err(Err::Error(error_position!(input, ErrorKind::NoneOf))), - } -} - -fn enum_const_value_parser(input: &[u8]) -> Result<(&[u8], EnumConstValue), Err<&[u8]>> { - let (input, type_name_index) = be_u16(input)?; - let (input, const_name_index) = be_u16(input)?; - Ok(( - input, - EnumConstValue { - type_name_index, - 
const_name_index, - }, - )) -} - -// not even really parsing ... -pub fn source_debug_extension_parser( - input: &[u8], -) -> Result<(&[u8], SourceDebugExtensionAttribute), Err<&[u8]>> { - let debug_extension = Vec::from(input); - Ok((input, SourceDebugExtensionAttribute { debug_extension })) -} - -pub fn line_number_table_attribute_parser( - input: &[u8], -) -> Result<(&[u8], LineNumberTable), Err<&[u8]>> { - let (input, line_number_table_length) = be_u16(input)?; - let (input, line_number_table) = count( - line_number_table_entry_parser, - line_number_table_length as usize, - )(input)?; - Ok(( - input, - LineNumberTable { - line_number_table_length, - line_number_table, - }, - )) -} - -pub fn line_number_table_entry_parser( - input: &[u8], -) -> Result<(&[u8], LineNumberTableEntry), Err<&[u8]>> { - let (input, start_pc) = be_u16(input)?; - let (input, line_number) = be_u16(input)?; - Ok(( - input, - LineNumberTableEntry { - start_pc, - line_number, - }, - )) -} - -fn same_frame_parser(input: &[u8], frame_type: u8) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - success(SameFrame { frame_type })(input) -} - -fn verification_type_parser(input: &[u8]) -> Result<(&[u8], VerificationTypeInfo), Err<&[u8]>> { - use self::VerificationTypeInfo::*; - let v = input[0]; - let new_input = &input[1..]; - match v { - 0 => Ok((new_input, Top)), - 1 => Ok((new_input, Integer)), - 2 => Ok((new_input, Float)), - 3 => Ok((new_input, Double)), - 4 => Ok((new_input, Long)), - 5 => Ok((new_input, Null)), - 6 => Ok((new_input, UninitializedThis)), - 7 => map(be_u16, |class| Object { class })(new_input), - 8 => map(be_u16, |offset| Uninitialized { offset })(new_input), - _ => Result::Err(Err::Error(error_position!(input, ErrorKind::NoneOf))), - } -} - -fn same_locals_1_stack_item_frame_parser( - input: &[u8], - frame_type: u8, -) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, stack) = verification_type_parser(input)?; - Ok((input, SameLocals1StackItemFrame { frame_type, 
stack })) -} - -fn same_locals_1_stack_item_frame_extended_parser( - input: &[u8], - frame_type: u8, -) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, offset_delta) = be_u16(input)?; - let (input, stack) = verification_type_parser(input)?; - Ok(( - input, - SameLocals1StackItemFrameExtended { - frame_type, - offset_delta, - stack, - }, - )) -} - -fn chop_frame_parser(input: &[u8], frame_type: u8) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, offset_delta) = be_u16(input)?; - Ok(( - input, - ChopFrame { - frame_type, - offset_delta, - }, - )) -} - -fn same_frame_extended_parser( - input: &[u8], - frame_type: u8, -) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, offset_delta) = be_u16(input)?; - Ok(( - input, - SameFrameExtended { - frame_type, - offset_delta, - }, - )) -} - -fn append_frame_parser(input: &[u8], frame_type: u8) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, offset_delta) = be_u16(input)?; - let (input, locals) = count(verification_type_parser, (frame_type - 251) as usize)(input)?; - Ok(( - input, - AppendFrame { - frame_type, - offset_delta, - locals, - }, - )) -} - -fn full_frame_parser(input: &[u8], frame_type: u8) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, offset_delta) = be_u16(input)?; - let (input, number_of_locals) = be_u16(input)?; - let (input, locals) = count(verification_type_parser, number_of_locals as usize)(input)?; - let (input, number_of_stack_items) = be_u16(input)?; - let (input, stack) = count(verification_type_parser, number_of_stack_items as usize)(input)?; - Ok(( - input, - FullFrame { - frame_type, - offset_delta, - number_of_locals, - locals, - number_of_stack_items, - stack, - }, - )) -} - -fn stack_frame_parser(input: &[u8], frame_type: u8) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - match frame_type { - 0..=63 => same_frame_parser(input, frame_type), - 64..=127 => same_locals_1_stack_item_frame_parser(input, frame_type), - 247 => 
same_locals_1_stack_item_frame_extended_parser(input, frame_type), - 248..=250 => chop_frame_parser(input, frame_type), - 251 => same_frame_extended_parser(input, frame_type), - 252..=254 => append_frame_parser(input, frame_type), - 255 => full_frame_parser(input, frame_type), - _ => Result::Err(Err::Error(error_position!(input, ErrorKind::NoneOf))), - } -} - -fn stack_map_frame_entry_parser(input: &[u8]) -> Result<(&[u8], StackMapFrame), Err<&[u8]>> { - let (input, frame_type) = be_u8(input)?; - let (input, stack_frame) = stack_frame_parser(input, frame_type)?; - Ok((input, stack_frame)) -} - -pub fn stack_map_table_attribute_parser( - input: &[u8], -) -> Result<(&[u8], StackMapTableAttribute), Err<&[u8]>> { - let (input, number_of_entries) = be_u16(input)?; - let (input, entries) = count(stack_map_frame_entry_parser, number_of_entries as usize)(input)?; - Ok(( - input, - StackMapTableAttribute { - number_of_entries, - entries, - }, - )) -} - -pub fn exceptions_attribute_parser( - input: &[u8], -) -> Result<(&[u8], ExceptionsAttribute), Err<&[u8]>> { - let (input, exception_table_length) = be_u16(input)?; - let (input, exception_table) = count(be_u16, exception_table_length as usize)(input)?; - Ok(( - input, - ExceptionsAttribute { - exception_table_length, - exception_table, - }, - )) -} - -pub fn constant_value_attribute_parser( - input: &[u8], -) -> Result<(&[u8], ConstantValueAttribute), Err<&[u8]>> { - let (input, constant_value_index) = be_u16(input)?; - Ok(( - input, - ConstantValueAttribute { - constant_value_index, - }, - )) -} - -fn bootstrap_method_parser(input: &[u8]) -> Result<(&[u8], BootstrapMethod), Err<&[u8]>> { - let (input, bootstrap_method_ref) = be_u16(input)?; - let (input, num_bootstrap_arguments) = be_u16(input)?; - let (input, bootstrap_arguments) = count(be_u16, num_bootstrap_arguments as usize)(input)?; - Ok(( - input, - BootstrapMethod { - bootstrap_method_ref, - num_bootstrap_arguments, - bootstrap_arguments, - }, - )) -} - -pub fn 
bootstrap_methods_attribute_parser(
-    input: &[u8],
-) -> Result<(&[u8], BootstrapMethodsAttribute), Err<&[u8]>> {
-    let (input, num_bootstrap_methods) = be_u16(input)?;
-    let (input, bootstrap_methods) =
-        count(bootstrap_method_parser, num_bootstrap_methods as usize)(input)?;
-    Ok((
-        input,
-        BootstrapMethodsAttribute {
-            num_bootstrap_methods,
-            bootstrap_methods,
-        },
-    ))
-}
-
-pub fn sourcefile_attribute_parser(
-    input: &[u8],
-) -> Result<(&[u8], SourceFileAttribute), Err<&[u8]>> {
-    let (input, sourcefile_index) = be_u16(input)?;
-    Ok((input, SourceFileAttribute { sourcefile_index }))
-}
diff --git a/src/attribute_info/types.rs b/src/attribute_info/types.rs
index 6febc66..c6e8067 100644
--- a/src/attribute_info/types.rs
+++ b/src/attribute_info/types.rs
@@ -1,4 +1,57 @@
-use binrw::binrw;
+use std::io::{Cursor, Seek};
+
+use binrw::{BinRead, BinResult, BinWrite, Endian, binrw, io::TakeSeekExt};
+
+use crate::{
+    InterpretInner,
+    code_attribute::{Instruction, LocalVariableTableAttribute, LocalVariableTypeTableAttribute},
+    constant_info::{ConstantInfo, Utf8Constant},
+};
+
+/// Custom parser for reading instructions from a code array.
+///
+/// Replaces `binrw::helpers::until_eof` because `Instruction` uses
+/// `return_unexpected_error` which produces `NoVariantMatch` on EOF.
+/// `until_eof` checks `err.is_eof()` to detect end-of-stream, but
+/// `NoVariantMatch.is_eof()` returns false, causing a spurious error.
+///
+/// This parser also computes the correct per-instruction `address`
+/// (offset within the code array) needed for tableswitch/lookupswitch
+/// alignment padding.
+#[binrw::parser(reader, endian)]
+fn parse_code_instructions(code_start: u64) -> BinResult<Vec<Instruction>> {
+    let mut instructions = Vec::new();
+    loop {
+        let pos = reader.stream_position()?;
+        let address = (pos - code_start) as u32;
+        match Instruction::read_options(reader, endian, binrw::args! { address: address }) {
+            Ok(instruction) => instructions.push(instruction),
+            Err(err) => {
+                reader.seek(std::io::SeekFrom::Start(pos))?;
+                let mut buf = [0u8; 1];
+                if reader.read(&mut buf)? == 0 {
+                    return Ok(instructions);
+                }
+                return Err(err);
+            }
+        }
+    }
+}
+
+/// Custom writer for serializing instructions back into the code array.
+///
+/// Tracks the running byte address so that tableswitch/lookupswitch
+/// padding is computed correctly.
+#[binrw::writer(writer, endian)]
+fn write_code_instructions(code: &Vec<Instruction>) -> BinResult<()> {
+    let start = writer.stream_position()?;
+    for instruction in code {
+        let pos = writer.stream_position()?;
+        let address = (pos - start) as u32;
+        instruction.write_options(writer, endian, binrw::args! { address: address })?;
+    }
+    Ok(())
+}
 
 #[derive(Clone, Debug)]
 #[binrw]
@@ -6,11 +59,254 @@ use binrw::binrw;
 pub struct AttributeInfo {
     pub attribute_name_index: u16,
     pub attribute_length: u32,
-    #[br(args { count: attribute_length.try_into().unwrap() })]
+    #[br(parse_with = validated_info_reader, args(attribute_length))]
     pub info: Vec<u8>,
+    #[brw(ignore)]
+    pub info_parsed: Option<AttributeInfoVariant>,
+}
+
+/// Read `attribute_length` bytes for AttributeInfo::info, but first validate
+/// that the declared length doesn't exceed remaining data (prevents OOM from
+/// malicious u32 values in untrusted class files).
+#[binrw::parser(reader)]
+fn validated_info_reader(length: u32) -> BinResult<Vec<u8>> {
+    let len = length as usize;
+    if len > 0 {
+        let pos = reader.stream_position()?;
+        let end = reader.seek(std::io::SeekFrom::End(0))?;
+        reader.seek(std::io::SeekFrom::Start(pos))?;
+        let remaining = end.saturating_sub(pos) as usize;
+        if len > remaining {
+            return Err(binrw::Error::AssertFail {
+                pos,
+                message: format!(
+                    "attribute_length {} exceeds remaining {} bytes",
+                    len, remaining
+                ),
+            });
+        }
+    }
+    let mut buf = vec![0u8; len];
+    reader.read_exact(&mut buf)?;
+    Ok(buf)
+}
+
+impl InterpretInner for AttributeInfo {
+    fn interpret_inner(&mut self, constant_pool: &Vec<ConstantInfo>) {
+        if self.info_parsed.is_some() {
+            return; // already parsed
+        }
+
+        if self.info.len() != self.attribute_length as usize {
+            return; // malformed: length mismatch, leave as raw bytes
+        }
+
+        // Bounds-checked constant pool access
+        let idx = self.attribute_name_index.wrapping_sub(1) as usize;
+        let attr_name = match constant_pool.get(idx) {
+            Some(ConstantInfo::Utf8(Utf8Constant { utf8_string })) => utf8_string.clone(),
+            _ => return, // index out of bounds or not UTF-8, leave as raw bytes
+        };
+
+        /// Helper: try to read a binrw type from `info`, returning `None` on failure.
+        macro_rules!
try_read { + ($ty:ty, $variant:ident) => { + <$ty>::read(&mut Cursor::new(&mut self.info)) + .ok() + .map(AttributeInfoVariant::$variant) + }; + (be: $ty:ty, $variant:ident) => { + <$ty>::read_be(&mut Cursor::new(&mut self.info)) + .ok() + .map(AttributeInfoVariant::$variant) + }; + } + + self.info_parsed = match attr_name.as_str() { + "ConstantValue" => try_read!(be: ConstantValueAttribute, ConstantValue), + "Code" => match CodeAttribute::read(&mut Cursor::new(&mut self.info)) { + Ok(mut code) => { + for attr in &mut code.attributes { + attr.interpret_inner(constant_pool); + } + Some(AttributeInfoVariant::Code(code)) + } + Err(_) => None, + }, + "StackMapTable" => try_read!(StackMapTableAttribute, StackMapTable), + "BootstrapMethods" => try_read!(BootstrapMethodsAttribute, BootstrapMethods), + "Exceptions" => try_read!(ExceptionsAttribute, Exceptions), + "InnerClasses" => try_read!(InnerClassesAttribute, InnerClasses), + "EnclosingMethod" => try_read!(EnclosingMethodAttribute, EnclosingMethod), + "Synthetic" => try_read!(SyntheticAttribute, Synthetic), + "Signature" => try_read!(SignatureAttribute, Signature), + "SourceFile" => try_read!(SourceFileAttribute, SourceFile), + "LineNumberTable" => try_read!(LineNumberTableAttribute, LineNumberTable), + "LocalVariableTable" => try_read!(LocalVariableTableAttribute, LocalVariableTable), + "LocalVariableTypeTable" => { + try_read!(LocalVariableTypeTableAttribute, LocalVariableTypeTable) + } + "SourceDebugExtension" => { + try_read!(SourceDebugExtensionAttribute, SourceDebugExtension) + } + "Deprecated" => try_read!(DeprecatedAttribute, Deprecated), + "RuntimeVisibleAnnotations" => try_read!( + RuntimeVisibleAnnotationsAttribute, + RuntimeVisibleAnnotations + ), + "RuntimeInvisibleAnnotations" => try_read!( + RuntimeInvisibleAnnotationsAttribute, + RuntimeInvisibleAnnotations + ), + "RuntimeVisibleParameterAnnotations" => try_read!( + RuntimeVisibleParameterAnnotationsAttribute, + RuntimeVisibleParameterAnnotations + ), + 
"RuntimeInvisibleParameterAnnotations" => try_read!( + RuntimeInvisibleParameterAnnotationsAttribute, + RuntimeInvisibleParameterAnnotations + ), + "RuntimeVisibleTypeAnnotations" => try_read!( + RuntimeVisibleTypeAnnotationsAttribute, + RuntimeVisibleTypeAnnotations + ), + "RuntimeInvisibleTypeAnnotations" => try_read!( + RuntimeInvisibleTypeAnnotationsAttribute, + RuntimeInvisibleTypeAnnotations + ), + "AnnotationDefault" => try_read!(AnnotationDefaultAttribute, AnnotationDefault), + "MethodParameters" => try_read!(MethodParametersAttribute, MethodParameters), + "Module" => try_read!(ModuleAttribute, Module), + "ModulePackages" => try_read!(ModulePackagesAttribute, ModulePackages), + "ModuleMainClass" => try_read!(ModuleMainClassAttribute, ModuleMainClass), + "NestHost" => try_read!(NestHostAttribute, NestHost), + "NestMembers" => try_read!(NestMembersAttribute, NestMembers), + "Record" => match RecordAttribute::read(&mut Cursor::new(&mut self.info)) { + Ok(mut record) => { + for component in &mut record.components { + for attr in &mut component.attributes { + attr.interpret_inner(constant_pool); + } + } + Some(AttributeInfoVariant::Record(record)) + } + Err(_) => None, + }, + "PermittedSubclasses" => try_read!(PermittedSubclassesAttribute, PermittedSubclasses), + unhandled => Some(AttributeInfoVariant::Unknown(String::from(unhandled))), + }; + } +} + +impl AttributeInfo { + /// Serializes `info_parsed` back into `info` bytes and updates `attribute_length`. + /// + /// Call this after modifying `info_parsed` to keep the raw bytes in sync. 
+ pub fn sync_from_parsed(&mut self) -> BinResult<()> { + let new_info = match &mut self.info_parsed { + Some(parsed) => { + let mut cursor = Cursor::new(Vec::new()); + match parsed { + AttributeInfoVariant::Code(v) => { + v.sync_lengths()?; + for attr in &mut v.attributes { + attr.sync_from_parsed()?; + } + v.write(&mut cursor)?; + } + AttributeInfoVariant::ConstantValue(v) => v.write(&mut cursor)?, + AttributeInfoVariant::StackMapTable(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Exceptions(v) => v.write(&mut cursor)?, + AttributeInfoVariant::InnerClasses(v) => v.write(&mut cursor)?, + AttributeInfoVariant::EnclosingMethod(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Synthetic(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Signature(v) => v.write(&mut cursor)?, + AttributeInfoVariant::SourceFile(v) => v.write(&mut cursor)?, + AttributeInfoVariant::SourceDebugExtension(v) => v.write(&mut cursor)?, + AttributeInfoVariant::LineNumberTable(v) => v.write(&mut cursor)?, + AttributeInfoVariant::LocalVariableTable(v) => v.write(&mut cursor)?, + AttributeInfoVariant::LocalVariableTypeTable(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Deprecated(v) => v.write(&mut cursor)?, + AttributeInfoVariant::RuntimeVisibleAnnotations(v) => v.write(&mut cursor)?, + AttributeInfoVariant::RuntimeInvisibleAnnotations(v) => v.write(&mut cursor)?, + AttributeInfoVariant::RuntimeVisibleParameterAnnotations(v) => { + v.write(&mut cursor)? + } + AttributeInfoVariant::RuntimeInvisibleParameterAnnotations(v) => { + v.write(&mut cursor)? + } + AttributeInfoVariant::RuntimeVisibleTypeAnnotations(v) => { + v.write(&mut cursor)? + } + AttributeInfoVariant::RuntimeInvisibleTypeAnnotations(v) => { + v.write(&mut cursor)? 
+ } + AttributeInfoVariant::AnnotationDefault(v) => v.write(&mut cursor)?, + AttributeInfoVariant::BootstrapMethods(v) => v.write(&mut cursor)?, + AttributeInfoVariant::MethodParameters(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Module(v) => v.write(&mut cursor)?, + AttributeInfoVariant::ModulePackages(v) => v.write(&mut cursor)?, + AttributeInfoVariant::ModuleMainClass(v) => v.write(&mut cursor)?, + AttributeInfoVariant::NestHost(v) => v.write(&mut cursor)?, + AttributeInfoVariant::NestMembers(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Record(v) => { + for component in &mut v.components { + for attr in &mut component.attributes { + attr.sync_from_parsed()?; + } + } + v.write(&mut cursor)?; + } + AttributeInfoVariant::PermittedSubclasses(v) => v.write(&mut cursor)?, + AttributeInfoVariant::Unknown(_) => return Ok(()), + } + cursor.into_inner() + } + None => return Ok(()), + }; + self.attribute_length = new_info.len() as u32; + self.info = new_info; + Ok(()) + } } #[derive(Clone, Debug)] +pub enum AttributeInfoVariant { + ConstantValue(ConstantValueAttribute), + Code(CodeAttribute), + StackMapTable(StackMapTableAttribute), + Exceptions(ExceptionsAttribute), + InnerClasses(InnerClassesAttribute), + EnclosingMethod(EnclosingMethodAttribute), + Synthetic(SyntheticAttribute), + Signature(SignatureAttribute), + SourceFile(SourceFileAttribute), + SourceDebugExtension(SourceDebugExtensionAttribute), + LineNumberTable(LineNumberTableAttribute), + LocalVariableTable(LocalVariableTableAttribute), + LocalVariableTypeTable(LocalVariableTypeTableAttribute), + Deprecated(DeprecatedAttribute), + RuntimeVisibleAnnotations(RuntimeVisibleAnnotationsAttribute), + RuntimeInvisibleAnnotations(RuntimeInvisibleAnnotationsAttribute), + RuntimeVisibleParameterAnnotations(RuntimeVisibleParameterAnnotationsAttribute), + RuntimeInvisibleParameterAnnotations(RuntimeInvisibleParameterAnnotationsAttribute), + 
RuntimeVisibleTypeAnnotations(RuntimeVisibleTypeAnnotationsAttribute), + RuntimeInvisibleTypeAnnotations(RuntimeInvisibleTypeAnnotationsAttribute), + AnnotationDefault(AnnotationDefaultAttribute), + BootstrapMethods(BootstrapMethodsAttribute), + MethodParameters(MethodParametersAttribute), + Module(ModuleAttribute), + ModulePackages(ModulePackagesAttribute), + ModuleMainClass(ModuleMainClassAttribute), + NestHost(NestHostAttribute), + NestMembers(NestMembersAttribute), + Record(RecordAttribute), + PermittedSubclasses(PermittedSubclassesAttribute), + Unknown(String), +} + +#[derive(Clone, Debug)] +#[binrw] pub struct ExceptionEntry { pub start_pc: u16, pub end_pc: u16, @@ -19,36 +315,119 @@ pub struct ExceptionEntry { } #[derive(Clone, Debug)] +#[binrw] +#[brw(big, stream = s)] pub struct CodeAttribute { pub max_stack: u16, pub max_locals: u16, pub code_length: u32, - pub code: Vec, + #[br(map_stream = |s| s.take_seek(code_length as u64), parse_with = parse_code_instructions, args(s.stream_position()?))] + #[bw(write_with = write_code_instructions)] + pub code: Vec, pub exception_table_length: u16, + #[br(count = exception_table_length)] pub exception_table: Vec, pub attributes_count: u16, + #[br(count = attributes_count)] pub attributes: Vec, } +impl CodeAttribute { + /// Recalculates `code_length`, `exception_table_length`, and `attributes_count` + /// from actual vector contents. Call this after modifying instructions or other + /// code attribute internals. + pub fn sync_lengths(&mut self) -> BinResult<()> { + let mut buf = Cursor::new(Vec::new()); + for instruction in &self.code { + let pos = buf.stream_position()?; + let address = pos as u32; + instruction.write_options(&mut buf, Endian::Big, binrw::args! 
{ address: address })?;
+        }
+        self.code_length = buf.into_inner().len() as u32;
+        self.exception_table_length = self.exception_table.len() as u16;
+        self.attributes_count = self.attributes.len() as u16;
+        Ok(())
+    }
+
+    /// Find the first instruction matching a predicate. Returns `(index, &Instruction)`.
+    pub fn find_instruction<F>(&self, predicate: F) -> Option<(usize, &Instruction)>
+    where
+        F: Fn(&Instruction) -> bool,
+    {
+        self.code
+            .iter()
+            .enumerate()
+            .find(|(_, instr)| predicate(instr))
+    }
+
+    /// Find all instructions matching a predicate. Returns `Vec<(index, &Instruction)>`.
+    pub fn find_instructions<F>(&self, predicate: F) -> Vec<(usize, &Instruction)>
+    where
+        F: Fn(&Instruction) -> bool,
+    {
+        self.code
+            .iter()
+            .enumerate()
+            .filter(|(_, instr)| predicate(instr))
+            .collect()
+    }
+
+    /// Replace the instruction at `index`.
+    pub fn replace_instruction(&mut self, index: usize, replacement: Instruction) {
+        self.code[index] = replacement;
+    }
+
+    /// Replace a range of instructions with the exact number of Nop instructions
+    /// needed to preserve `code_length`. Handles variable-length instructions
+    /// (tableswitch, lookupswitch) by serializing to compute byte sizes.
+    pub fn nop_out(&mut self, range: std::ops::Range<usize>) -> BinResult<()> {
+        let mut buf = Cursor::new(Vec::new());
+        // Serialize instructions before range to find byte offset
+        for instr in &self.code[..range.start] {
+            let address = buf.stream_position()? as u32;
+            instr.write_options(&mut buf, Endian::Big, binrw::args! { address })?;
+        }
+        // Serialize instructions in range to find byte count
+        let range_start_pos = buf.stream_position()?;
+        for instr in &self.code[range.clone()] {
+            let address = buf.stream_position()? as u32;
+            instr.write_options(&mut buf, Endian::Big, binrw::args! { address })?;
+        }
+        let byte_count = (buf.stream_position()?
- range_start_pos) as usize; + self.code.splice(range, vec![Instruction::Nop; byte_count]); + Ok(()) + } +} + #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct MethodParametersAttribute { pub parameters_count: u8, + #[br(count = parameters_count)] pub parameters: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct ParameterAttribute { pub name_index: u16, pub access_flags: u16, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct InnerClassesAttribute { pub number_of_classes: u16, + #[br(count = number_of_classes)] pub classes: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct InnerClassInfo { pub inner_class_info_index: u16, pub outer_class_info_index: u16, @@ -73,96 +452,130 @@ bitflags! { } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct EnclosingMethodAttribute { pub class_index: u16, pub method_index: u16, } +// in all reality this struct isn't required b/c it's zero sized +// "Deprecated" is a marker attribute +#[derive(Clone, Debug)] +#[binrw] +pub struct DeprecatedAttribute {} + // in all reality this struct isn't required b/c it's zero sized // "Synthetic" is a marker attribute #[derive(Clone, Debug)] +#[binrw] pub struct SyntheticAttribute {} #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct SignatureAttribute { pub signature_index: u16, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct RuntimeVisibleAnnotationsAttribute { pub num_annotations: u16, + #[br(count = num_annotations)] pub annotations: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct RuntimeInvisibleAnnotationsAttribute { pub num_annotations: u16, + #[br(count = num_annotations)] pub annotations: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct RuntimeVisibleParameterAnnotationsAttribute { pub num_parameters: u8, + #[br(count = num_parameters)] pub parameter_annotations: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct RuntimeInvisibleParameterAnnotationsAttribute { pub 
num_parameters: u8, + #[br(count = num_parameters)] pub parameter_annotations: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct RuntimeVisibleTypeAnnotationsAttribute { pub num_annotations: u16, + #[br(count = num_annotations)] pub type_annotations: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct RuntimeInvisibleTypeAnnotationsAttribute { pub num_annotations: u16, + #[br(count = num_annotations)] pub type_annotations: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct TypeAnnotation { pub target_type: u8, + #[br(args(target_type))] pub target_info: TargetInfo, pub target_path: TypePath, pub type_index: u16, pub num_element_value_pairs: u16, + #[br(count = num_element_value_pairs)] pub element_value_pairs: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[br(import(target_type: u8))] pub enum TargetInfo { - TypeParameter { - type_parameter_index: u8, - }, - SuperType { - supertype_index: u16, - }, + #[br(pre_assert(target_type == 0x00 || target_type == 0x01))] + TypeParameter { type_parameter_index: u8 }, + #[br(pre_assert(target_type == 0x10))] + SuperType { supertype_index: u16 }, + #[br(pre_assert(target_type == 0x11 || target_type == 0x12))] TypeParameterBound { type_parameter_index: u8, bound_index: u8, }, + #[br(pre_assert((0x13..=0x15).contains(&target_type)))] Empty, - FormalParameter { - formal_parameter_index: u8, - }, - Throws { - throws_type_index: u16, - }, + #[br(pre_assert(target_type == 0x16))] + FormalParameter { formal_parameter_index: u8 }, + #[br(pre_assert(target_type == 0x17))] + Throws { throws_type_index: u16 }, + #[br(pre_assert(target_type == 0x40 || target_type == 0x41))] LocalVar { table_length: u16, + #[br(count = table_length)] tables: Vec, }, - Catch { - exception_table_index: u16, - }, - Offset { - offset: u16, - }, + #[br(pre_assert(target_type == 0x42))] + Catch { exception_table_index: u16 }, + #[br(pre_assert((0x43..=0x46).contains(&target_type)))] + Offset { offset: u16 }, + 
#[br(pre_assert((0x47..=0x4B).contains(&target_type)))] TypeArgument { offset: u16, type_argument_index: u8, @@ -170,18 +583,22 @@ pub enum TargetInfo { } #[derive(Clone, Debug)] +#[binrw] pub struct TypePath { pub path_length: u8, + #[br(count = path_length)] pub paths: Vec, } #[derive(Clone, Debug)] +#[binrw] pub struct TypePathEntry { pub type_path_kind: u8, pub type_argument_index: u8, } #[derive(Clone, Debug)] +#[binrw] pub struct LocalVarTableAnnotation { pub start_pc: u16, pub length: u16, @@ -189,24 +606,43 @@ pub struct LocalVarTableAnnotation { } #[derive(Clone, Debug)] +#[binrw] pub struct RuntimeAnnotation { pub type_index: u16, pub num_element_value_pairs: u16, + #[br(count = num_element_value_pairs)] pub element_value_pairs: Vec, } -pub type DefaultAnnotation = ElementValue; +pub type AnnotationDefaultAttribute = ElementValue; #[derive(Clone, Debug)] +#[binrw] pub struct ElementValuePair { pub element_name_index: u16, pub value: ElementValue, } +#[binrw::parser(reader)] +fn custom_char_parser() -> BinResult { + let mut buf = [0u8; 1]; + reader.read_exact(&mut buf)?; + let c = u8::from_be_bytes(buf) as char; + Ok(c) +} + +#[binrw::writer(writer)] +pub fn custom_char_writer(c: &char) -> BinResult<()> { + writer.write_all(c.to_string().as_bytes())?; + Ok(()) +} + #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub enum ElementValue { // pub tag: u8, - ConstValueIndex { tag: char, value: u16 }, + ConstValueIndex(ConstValueIndexValue), EnumConst(EnumConstValue), ClassInfoIndex(u16), AnnotationValue(RuntimeAnnotation), @@ -214,51 +650,80 @@ pub enum ElementValue { } #[derive(Clone, Debug)] +#[binrw] +pub struct ConstValueIndexValue { + #[br(parse_with = custom_char_parser)] + #[bw(write_with = custom_char_writer)] + pub tag: char, + pub value: u16, +} + +#[derive(Clone, Debug)] +#[binrw] pub struct ElementArrayValue { pub num_values: u16, + #[br(count = num_values)] pub values: Vec, } #[derive(Clone, Debug)] +#[binrw] pub struct EnumConstValue { pub 
type_name_index: u16, pub const_name_index: u16, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct SourceDebugExtensionAttribute { // Per the spec: // The debug_extension array holds extended debugging information which has no // semantic effect on the Java Virtual Machine. The information is represented // using a modified UTF-8 string with no terminating zero byte. - pub debug_extension: Vec, + // pub debug_extension: Vec, } #[derive(Clone, Debug)] -pub struct LineNumberTable { +#[binrw] +#[brw(big)] +pub struct LineNumberTableAttribute { pub line_number_table_length: u16, + #[br(count = line_number_table_length)] pub line_number_table: Vec, } #[derive(Clone, Debug)] +#[binrw] pub struct LineNumberTableEntry { pub start_pc: u16, pub line_number: u16, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub enum VerificationTypeInfo { + #[brw(magic = 0u8)] Top, + #[brw(magic = 1u8)] Integer, + #[brw(magic = 2u8)] Float, + #[brw(magic = 3u8)] Double, + #[brw(magic = 4u8)] Long, + #[brw(magic = 5u8)] Null, + #[brw(magic = 6u8)] UninitializedThis, + #[brw(magic = 7u8)] Object { /// An index into the constant pool for the class of the object class: u16, }, + #[brw(magic = 8u8)] Uninitialized { /// Offset into associated code array of a new instruction /// that created the object being stored here. 
@@ -267,69 +732,104 @@ pub enum VerificationTypeInfo { } #[derive(Clone, Debug)] -pub enum StackMapFrame { +#[binrw] +#[brw(big)] +pub struct StackMapFrame { + pub frame_type: u8, + #[br(args(frame_type))] + pub inner: StackMapFrameInner, +} + +#[derive(Clone, Debug)] +#[binrw] +#[br(import(frame_type: u8))] +pub enum StackMapFrameInner { + #[br(pre_assert((0..=63).contains(&frame_type)))] SameFrame { - frame_type: u8, + //frame_type: u8, }, + #[br(pre_assert((64..=127).contains(&frame_type)))] SameLocals1StackItemFrame { - frame_type: u8, + //frame_type: u8, stack: VerificationTypeInfo, }, + #[br(pre_assert(frame_type == 247))] SameLocals1StackItemFrameExtended { - frame_type: u8, + //frame_type: u8, offset_delta: u16, stack: VerificationTypeInfo, }, + #[br(pre_assert((248..=250).contains(&frame_type)))] ChopFrame { - frame_type: u8, + //frame_type: u8, offset_delta: u16, }, + #[br(pre_assert(frame_type == 251))] SameFrameExtended { - frame_type: u8, + //frame_type: u8, offset_delta: u16, }, + #[br(pre_assert((252..=254).contains(&frame_type)))] AppendFrame { - frame_type: u8, + //frame_type: u8, offset_delta: u16, + #[br(count = frame_type - 251)] locals: Vec, }, + #[br(pre_assert(frame_type == 255))] FullFrame { - frame_type: u8, + //frame_type: u8, offset_delta: u16, number_of_locals: u16, + #[br(count = number_of_locals)] locals: Vec, number_of_stack_items: u16, + #[br(count = number_of_stack_items)] stack: Vec, }, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct StackMapTableAttribute { pub number_of_entries: u16, + #[br(count = number_of_entries)] pub entries: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct ExceptionsAttribute { pub exception_table_length: u16, + #[br(count = exception_table_length)] pub exception_table: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct ConstantValueAttribute { pub constant_value_index: u16, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct BootstrapMethod { pub 
bootstrap_method_ref: u16, pub num_bootstrap_arguments: u16, + #[br(count = num_bootstrap_arguments)] pub bootstrap_arguments: Vec, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct BootstrapMethodsAttribute { pub num_bootstrap_methods: u16, + #[br(count = num_bootstrap_methods)] pub bootstrap_methods: Vec, } @@ -338,8 +838,136 @@ pub struct BootstrapMethodsAttribute { /// There may be at most one SourceFile attribute in the attributes table of a ClassFile structure. /// [see more](https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.7.10) #[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] pub struct SourceFileAttribute { /// The value of the sourcefile_index item must be a valid index into the constant_pool table. /// The constant_pool entry at that index must be a CONSTANT_Utf8_info structure representing a string. pub sourcefile_index: u16, } + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct ModuleAttribute { + pub module_name_index: u16, + pub module_flags: u16, + pub module_version_index: u16, + pub requires_count: u16, + #[br(count = requires_count)] + pub requires: Vec, + pub exports_count: u16, + #[br(count = exports_count)] + pub exports: Vec, + pub opens_count: u16, + #[br(count = opens_count)] + pub opens: Vec, + pub uses_count: u16, + #[br(count = uses_count)] + pub uses: Vec, + pub provides_count: u16, + #[br(count = provides_count)] + pub provides: Vec, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct ModuleRequiresAttribute { + pub requires_index: u16, + pub requires_flags: u16, + pub requires_version_index: u16, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct ModuleExportsAttribute { + pub exports_index: u16, + pub exports_flags: u16, + pub exports_to_count: u16, + #[br(count = exports_to_count)] + pub exports_to_index: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct 
ModuleOpensAttribute { + pub opens_index: u16, + pub opens_flags: u16, + pub opens_to_count: u16, + #[br(count = opens_to_count)] + pub opens_to_index: Vec<u16>, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct ModuleProvidesAttribute { + pub provides_index: u16, + pub provides_with_count: u16, + #[br(count = provides_with_count)] + pub provides_with_index: Vec<u16>, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct ModulePackagesAttribute { + pub package_count: u16, + #[br(count = package_count)] + pub package_index: Vec<u16>, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct ModuleMainClassAttribute { + pub main_class_index: u16, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct NestHostAttribute { + pub host_class_index: u16, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct NestMembersAttribute { + pub number_of_classes: u16, + #[br(count = number_of_classes)] + pub classes: Vec<u16>, +} + +#[derive(Clone, Debug)] +#[binrw] +#[brw(big)] +pub struct RecordAttribute { + pub components_count: u16, + #[br(count = components_count)] + pub components: Vec<RecordComponentInfo>, +} + +#[derive(Clone, Debug)] +#[binrw] +#[brw(big)] +pub struct RecordComponentInfo { + pub name_index: u16, + pub descriptor_index: u16, + pub attributes_count: u16, + #[br(count = attributes_count)] + pub attributes: Vec<AttributeInfo>, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[brw(big)] +pub struct PermittedSubclassesAttribute { + pub number_of_classes: u16, + #[br(count = number_of_classes)] + pub classes: Vec<u16>, +} diff --git a/src/code_attribute/mod.rs b/src/code_attribute/mod.rs index cae52dd..6108ca4 100644 --- a/src/code_attribute/mod.rs +++ b/src/code_attribute/mod.rs @@ -1,9 +1,3 @@ -mod parser; mod types; pub use self::types::*; - -pub use self::parser::code_parser; -pub use self::parser::instruction_parser; -pub use self::parser::local_variable_table_parser;
-pub use self::parser::local_variable_type_table_parser; diff --git a/src/code_attribute/parser.rs b/src/code_attribute/parser.rs deleted file mode 100644 index f706f8d..0000000 --- a/src/code_attribute/parser.rs +++ /dev/null @@ -1,373 +0,0 @@ -use crate::code_attribute::types::Instruction; -use nom::{ - Err as BaseErr, IResult, Offset, - bytes::complete::{tag, take}, - combinator::{complete, fail, map, success}, - error::Error, - multi::{count, many0}, - number::complete::{be_i8, be_i16, be_i32, be_u8, be_u16, be_u32}, - sequence::{pair, preceded, tuple}, -}; - -use super::{ - LocalVariableTableAttribute, LocalVariableTableItem, LocalVariableTypeTableAttribute, - LocalVariableTypeTableItem, -}; -type Err<E> = BaseErr<Error<E>>; - -fn offset<'a>(remaining: &'a [u8], input: &[u8]) -> IResult<&'a [u8], usize> { - Ok((remaining, input.offset(remaining))) -} - -fn align(address: usize) -> impl Fn(&[u8]) -> IResult<&[u8], &[u8]> { - move |input: &[u8]| take((4 - address % 4) % 4)(input) -} - -fn lookupswitch_parser(input: &[u8]) -> IResult<&[u8], Instruction> { - // This function provides type annotations required by rustc.
- fn each_pair(input: &[u8]) -> IResult<&[u8], (i32, i32)> { - let (input, lookup) = be_i32(input)?; - let (input, offset) = be_i32(input)?; - Ok((input, (lookup, offset))) - } - let (input, default) = be_i32(input)?; - let (input, npairs) = be_u32(input)?; - let (input, pairs) = count(each_pair, npairs as usize)(input)?; - Ok((input, Instruction::Lookupswitch { default, pairs })) -} - -fn tableswitch_parser(input: &[u8]) -> IResult<&[u8], Instruction> { - let (input, default) = be_i32(input)?; - let (input, low) = be_i32(input)?; - let (input, high) = be_i32(input)?; - let (input, offsets) = count(be_i32, (high - low + 1) as usize)(input)?; - Ok(( - input, - Instruction::Tableswitch { - default, - low, - high, - offsets, - }, - )) -} - -pub fn code_parser(outer_input: &[u8]) -> IResult<&[u8], Vec<(usize, Instruction)>> { - many0(complete(|input| { - let (input, address) = offset(input, outer_input)?; - let (input, instruction) = instruction_parser(input, address)?; - Ok((input, (address, instruction))) - }))(outer_input) -} - -pub fn instruction_parser(input: &[u8], address: usize) -> IResult<&[u8], Instruction> { - let (input, b0) = be_u8(input)?; - let (input, instruction) = match b0 { - 0x32 => success(Instruction::Aaload)(input)?, - 0x53 => success(Instruction::Aastore)(input)?, - 0x01 => success(Instruction::Aconstnull)(input)?, - 0x19 => map(be_u8, Instruction::Aload)(input)?, - 0x2a => success(Instruction::Aload0)(input)?, - 0x2b => success(Instruction::Aload1)(input)?, - 0x2c => success(Instruction::Aload2)(input)?, - 0x2d => success(Instruction::Aload3)(input)?, - 0xbd => map(be_u16, Instruction::Anewarray)(input)?, - 0xb0 => success(Instruction::Areturn)(input)?, - 0xbe => success(Instruction::Arraylength)(input)?, - 0x3a => map(be_u8, Instruction::Astore)(input)?, - 0x4b => success(Instruction::Astore0)(input)?, - 0x4c => success(Instruction::Astore1)(input)?, - 0x4d => success(Instruction::Astore2)(input)?, - 0x4e => 
success(Instruction::Astore3)(input)?, - 0xbf => success(Instruction::Athrow)(input)?, - 0x33 => success(Instruction::Baload)(input)?, - 0x54 => success(Instruction::Bastore)(input)?, - 0x10 => map(be_i8, Instruction::Bipush)(input)?, - 0x34 => success(Instruction::Caload)(input)?, - 0x55 => success(Instruction::Castore)(input)?, - 0xc0 => map(be_u16, Instruction::Checkcast)(input)?, - 0x90 => success(Instruction::D2f)(input)?, - 0x8e => success(Instruction::D2i)(input)?, - 0x8f => success(Instruction::D2l)(input)?, - 0x63 => success(Instruction::Dadd)(input)?, - 0x31 => success(Instruction::Daload)(input)?, - 0x52 => success(Instruction::Dastore)(input)?, - 0x98 => success(Instruction::Dcmpg)(input)?, - 0x97 => success(Instruction::Dcmpl)(input)?, - 0x0e => success(Instruction::Dconst0)(input)?, - 0x0f => success(Instruction::Dconst1)(input)?, - 0x6f => success(Instruction::Ddiv)(input)?, - 0x18 => map(be_u8, Instruction::Dload)(input)?, - 0x26 => success(Instruction::Dload0)(input)?, - 0x27 => success(Instruction::Dload1)(input)?, - 0x28 => success(Instruction::Dload2)(input)?, - 0x29 => success(Instruction::Dload3)(input)?, - 0x6b => success(Instruction::Dmul)(input)?, - 0x77 => success(Instruction::Dneg)(input)?, - 0x73 => success(Instruction::Drem)(input)?, - 0xaf => success(Instruction::Dreturn)(input)?, - 0x39 => map(be_u8, Instruction::Dstore)(input)?, - 0x47 => success(Instruction::Dstore0)(input)?, - 0x48 => success(Instruction::Dstore1)(input)?, - 0x49 => success(Instruction::Dstore2)(input)?, - 0x4a => success(Instruction::Dstore3)(input)?, - 0x67 => success(Instruction::Dsub)(input)?, - 0x59 => success(Instruction::Dup)(input)?, - 0x5a => success(Instruction::Dupx1)(input)?, - 0x5b => success(Instruction::Dupx2)(input)?, - 0x5c => success(Instruction::Dup2)(input)?, - 0x5d => success(Instruction::Dup2x1)(input)?, - 0x5e => success(Instruction::Dup2x2)(input)?, - 0x8d => success(Instruction::F2d)(input)?, - 0x8b => success(Instruction::F2i)(input)?, - 
0x8c => success(Instruction::F2l)(input)?, - 0x62 => success(Instruction::Fadd)(input)?, - 0x30 => success(Instruction::Faload)(input)?, - 0x51 => success(Instruction::Fastore)(input)?, - 0x96 => success(Instruction::Fcmpg)(input)?, - 0x95 => success(Instruction::Fcmpl)(input)?, - 0x0b => success(Instruction::Fconst0)(input)?, - 0x0c => success(Instruction::Fconst1)(input)?, - 0x0d => success(Instruction::Fconst2)(input)?, - 0x6e => success(Instruction::Fdiv)(input)?, - 0x17 => map(be_u8, Instruction::Fload)(input)?, - 0x22 => success(Instruction::Fload0)(input)?, - 0x23 => success(Instruction::Fload1)(input)?, - 0x24 => success(Instruction::Fload2)(input)?, - 0x25 => success(Instruction::Fload3)(input)?, - 0x6a => success(Instruction::Fmul)(input)?, - 0x76 => success(Instruction::Fneg)(input)?, - 0x72 => success(Instruction::Frem)(input)?, - 0xae => success(Instruction::Freturn)(input)?, - 0x38 => map(be_u8, Instruction::Fstore)(input)?, - 0x43 => success(Instruction::Fstore0)(input)?, - 0x44 => success(Instruction::Fstore1)(input)?, - 0x45 => success(Instruction::Fstore2)(input)?, - 0x46 => success(Instruction::Fstore3)(input)?, - 0x66 => success(Instruction::Fsub)(input)?, - 0xb4 => map(be_u16, Instruction::Getfield)(input)?, - 0xb2 => map(be_u16, Instruction::Getstatic)(input)?, - 0xa7 => map(be_i16, Instruction::Goto)(input)?, - 0xc8 => map(be_i32, Instruction::GotoW)(input)?, - 0x91 => success(Instruction::I2b)(input)?, - 0x92 => success(Instruction::I2c)(input)?, - 0x87 => success(Instruction::I2d)(input)?, - 0x86 => success(Instruction::I2f)(input)?, - 0x85 => success(Instruction::I2l)(input)?, - 0x93 => success(Instruction::I2s)(input)?, - 0x60 => success(Instruction::Iadd)(input)?, - 0x2e => success(Instruction::Iaload)(input)?, - 0x7e => success(Instruction::Iand)(input)?, - 0x4f => success(Instruction::Iastore)(input)?, - 0x02 => success(Instruction::Iconstm1)(input)?, - 0x03 => success(Instruction::Iconst0)(input)?, - 0x04 => 
success(Instruction::Iconst1)(input)?, - 0x05 => success(Instruction::Iconst2)(input)?, - 0x06 => success(Instruction::Iconst3)(input)?, - 0x07 => success(Instruction::Iconst4)(input)?, - 0x08 => success(Instruction::Iconst5)(input)?, - 0x6c => success(Instruction::Idiv)(input)?, - 0xa5 => map(be_i16, Instruction::IfAcmpeq)(input)?, - 0xa6 => map(be_i16, Instruction::IfAcmpne)(input)?, - 0x9f => map(be_i16, Instruction::IfIcmpeq)(input)?, - 0xa0 => map(be_i16, Instruction::IfIcmpne)(input)?, - 0xa1 => map(be_i16, Instruction::IfIcmplt)(input)?, - 0xa2 => map(be_i16, Instruction::IfIcmpge)(input)?, - 0xa3 => map(be_i16, Instruction::IfIcmpgt)(input)?, - 0xa4 => map(be_i16, Instruction::IfIcmple)(input)?, - 0x99 => map(be_i16, Instruction::Ifeq)(input)?, - 0x9a => map(be_i16, Instruction::Ifne)(input)?, - 0x9b => map(be_i16, Instruction::Iflt)(input)?, - 0x9c => map(be_i16, Instruction::Ifge)(input)?, - 0x9d => map(be_i16, Instruction::Ifgt)(input)?, - 0x9e => map(be_i16, Instruction::Ifle)(input)?, - 0xc7 => map(be_i16, Instruction::Ifnonnull)(input)?, - 0xc6 => map(be_i16, Instruction::Ifnull)(input)?, - 0x84 => map(pair(be_u8, be_i8), |(index, value)| Instruction::Iinc { - index, - value, - })(input)?, - 0x15 => map(be_u8, Instruction::Iload)(input)?, - 0x1a => success(Instruction::Iload0)(input)?, - 0x1b => success(Instruction::Iload1)(input)?, - 0x1c => success(Instruction::Iload2)(input)?, - 0x1d => success(Instruction::Iload3)(input)?, - 0x68 => success(Instruction::Imul)(input)?, - 0x74 => success(Instruction::Ineg)(input)?, - 0xc1 => map(be_u16, Instruction::Instanceof)(input)?, - 0xba => map(pair(be_u16, tag(&[0, 0])), |(index, _)| { - Instruction::Invokedynamic(index) - })(input)?, - 0xb9 => map(tuple((be_u16, be_u8, tag(&[0]))), |(index, count, _)| { - Instruction::Invokeinterface { index, count } - })(input)?, - 0xb7 => map(be_u16, Instruction::Invokespecial)(input)?, - 0xb8 => map(be_u16, Instruction::Invokestatic)(input)?, - 0xb6 => map(be_u16, 
Instruction::Invokevirtual)(input)?, - 0x80 => success(Instruction::Ior)(input)?, - 0x70 => success(Instruction::Irem)(input)?, - 0xac => success(Instruction::Ireturn)(input)?, - 0x78 => success(Instruction::Ishl)(input)?, - 0x7a => success(Instruction::Ishr)(input)?, - 0x36 => map(be_u8, Instruction::Istore)(input)?, - 0x3b => success(Instruction::Istore0)(input)?, - 0x3c => success(Instruction::Istore1)(input)?, - 0x3d => success(Instruction::Istore2)(input)?, - 0x3e => success(Instruction::Istore3)(input)?, - 0x64 => success(Instruction::Isub)(input)?, - 0x7c => success(Instruction::Iushr)(input)?, - 0x82 => success(Instruction::Ixor)(input)?, - 0xa8 => map(be_i16, Instruction::Jsr)(input)?, - 0xc9 => map(be_i32, Instruction::JsrW)(input)?, - 0x8a => success(Instruction::L2d)(input)?, - 0x89 => success(Instruction::L2f)(input)?, - 0x88 => success(Instruction::L2i)(input)?, - 0x61 => success(Instruction::Ladd)(input)?, - 0x2f => success(Instruction::Laload)(input)?, - 0x7f => success(Instruction::Land)(input)?, - 0x50 => success(Instruction::Lastore)(input)?, - 0x94 => success(Instruction::Lcmp)(input)?, - 0x09 => success(Instruction::Lconst0)(input)?, - 0x0a => success(Instruction::Lconst1)(input)?, - 0x12 => map(be_u8, Instruction::Ldc)(input)?, - 0x13 => map(be_u16, Instruction::LdcW)(input)?, - 0x14 => map(be_u16, Instruction::Ldc2W)(input)?, - 0x6d => success(Instruction::Ldiv)(input)?, - 0x16 => map(be_u8, Instruction::Lload)(input)?, - 0x1e => success(Instruction::Lload0)(input)?, - 0x1f => success(Instruction::Lload1)(input)?, - 0x20 => success(Instruction::Lload2)(input)?, - 0x21 => success(Instruction::Lload3)(input)?, - 0x69 => success(Instruction::Lmul)(input)?, - 0x75 => success(Instruction::Lneg)(input)?, - 0xab => preceded(align(address + 1), lookupswitch_parser)(input)?, - 0x81 => success(Instruction::Lor)(input)?, - 0x71 => success(Instruction::Lrem)(input)?, - 0xad => success(Instruction::Lreturn)(input)?, - 0x79 => 
success(Instruction::Lshl)(input)?, - 0x7b => success(Instruction::Lshr)(input)?, - 0x37 => map(be_u8, Instruction::Lstore)(input)?, - 0x3f => success(Instruction::Lstore0)(input)?, - 0x40 => success(Instruction::Lstore1)(input)?, - 0x41 => success(Instruction::Lstore2)(input)?, - 0x42 => success(Instruction::Lstore3)(input)?, - 0x65 => success(Instruction::Lsub)(input)?, - 0x7d => success(Instruction::Lushr)(input)?, - 0x83 => success(Instruction::Lxor)(input)?, - 0xc2 => success(Instruction::Monitorenter)(input)?, - 0xc3 => success(Instruction::Monitorexit)(input)?, - 0xc5 => map(pair(be_u16, be_u8), |(index, dimensions)| { - Instruction::Multianewarray { index, dimensions } - })(input)?, - 0xbb => map(be_u16, Instruction::New)(input)?, - 0xbc => map(be_u8, Instruction::Newarray)(input)?, - 0x00 => success(Instruction::Nop)(input)?, - 0x57 => success(Instruction::Pop)(input)?, - 0x58 => success(Instruction::Pop2)(input)?, - 0xb5 => map(be_u16, Instruction::Putfield)(input)?, - 0xb3 => map(be_u16, Instruction::Putstatic)(input)?, - 0xa9 => map(be_u8, Instruction::Ret)(input)?, - 0xb1 => success(Instruction::Return)(input)?, - 0x35 => success(Instruction::Saload)(input)?, - 0x56 => success(Instruction::Sastore)(input)?, - 0x11 => map(be_i16, Instruction::Sipush)(input)?, - 0x5f => success(Instruction::Swap)(input)?, - 0xaa => preceded(align(address + 1), tableswitch_parser)(input)?, - 0xc4 => { - let (input, b1) = be_u8(input)?; - match b1 { - 0x19 => map(be_u16, Instruction::AloadWide)(input)?, - 0x3a => map(be_u16, Instruction::AstoreWide)(input)?, - 0x18 => map(be_u16, Instruction::DloadWide)(input)?, - 0x39 => map(be_u16, Instruction::DstoreWide)(input)?, - 0x17 => map(be_u16, Instruction::FloadWide)(input)?, - 0x38 => map(be_u16, Instruction::FstoreWide)(input)?, - 0x15 => map(be_u16, Instruction::IloadWide)(input)?, - 0x36 => map(be_u16, Instruction::IstoreWide)(input)?, - 0x16 => map(be_u16, Instruction::LloadWide)(input)?, - 0x37 => map(be_u16, 
Instruction::LstoreWide)(input)?, - 0xa9 => map(be_u16, Instruction::RetWide)(input)?, - 0x84 => map(pair(be_u16, be_i16), |(index, value)| { - Instruction::IincWide { index, value } - })(input)?, - _ => fail(input)?, - } - } - _ => fail(input)?, - }; - Ok((input, instruction)) -} - -pub fn local_variable_table_parser( - input: &[u8], -) -> Result<(&[u8], LocalVariableTableAttribute), Err<&[u8]>> { - let (input, local_variable_table_length) = be_u16(input)?; - let (input, items) = count( - variable_table_item_parser, - local_variable_table_length as usize, - )(input)?; - Ok(( - input, - LocalVariableTableAttribute { - local_variable_table_length, - items, - }, - )) -} - -pub fn variable_table_item_parser( - input: &[u8], -) -> Result<(&[u8], LocalVariableTableItem), Err<&[u8]>> { - let (input, start_pc) = be_u16(input)?; - let (input, length) = be_u16(input)?; - let (input, name_index) = be_u16(input)?; - let (input, descriptor_index) = be_u16(input)?; - let (input, index) = be_u16(input)?; - Ok(( - input, - LocalVariableTableItem { - start_pc, - length, - name_index, - descriptor_index, - index, - }, - )) -} - -pub fn local_variable_type_table_parser( - input: &[u8], -) -> Result<(&[u8], LocalVariableTypeTableAttribute), Err<&[u8]>> { - let (input, local_variable_type_table_length) = be_u16(input)?; - let (input, local_variable_type_table) = count( - local_variable_type_table_item_parser, - local_variable_type_table_length as usize, - )(input)?; - Ok(( - input, - LocalVariableTypeTableAttribute { - local_variable_type_table_length, - local_variable_type_table, - }, - )) -} - -pub fn local_variable_type_table_item_parser( - input: &[u8], -) -> Result<(&[u8], LocalVariableTypeTableItem), Err<&[u8]>> { - let (input, start_pc) = be_u16(input)?; - let (input, length) = be_u16(input)?; - let (input, name_index) = be_u16(input)?; - let (input, signature_index) = be_u16(input)?; - let (input, index) = be_u16(input)?; - Ok(( - input, - LocalVariableTypeTableItem { - 
start_pc, - length, - name_index, - signature_index, - index, - }, - )) -} diff --git a/src/code_attribute/types.rs b/src/code_attribute/types.rs index 5d336a5..ccf0922 100644 --- a/src/code_attribute/types.rs +++ b/src/code_attribute/types.rs @@ -1,247 +1,464 @@ +use binrw::binrw; + #[derive(Clone, Debug, Eq, PartialEq)] +#[binrw] +#[br(return_unexpected_error, import { address: u32 })] +#[bw(import { address: u32 })] +#[brw(big)] pub enum Instruction { + #[brw(magic = 0x00u8)] + Nop, + #[brw(magic = 0x32u8)] Aaload, + #[brw(magic = 0x53u8)] Aastore, + #[brw(magic = 0x01u8)] Aconstnull, + #[brw(magic = 0x19u8)] Aload(u8), - AloadWide(u16), + #[brw(magic = 0x2au8)] Aload0, + #[brw(magic = 0x2bu8)] Aload1, + #[brw(magic = 0x2cu8)] Aload2, + #[brw(magic = 0x2du8)] Aload3, + #[brw(magic = 0xbdu8)] Anewarray(u16), + #[brw(magic = 0xb0u8)] Areturn, + #[brw(magic = 0xbeu8)] Arraylength, + #[brw(magic = 0x3au8)] Astore(u8), - AstoreWide(u16), + #[brw(magic = 0x4bu8)] Astore0, + #[brw(magic = 0x4cu8)] Astore1, + #[brw(magic = 0x4du8)] Astore2, + #[brw(magic = 0x4eu8)] Astore3, + #[brw(magic = 0xbfu8)] Athrow, + #[brw(magic = 0x33u8)] Baload, + #[brw(magic = 0x54u8)] Bastore, + #[brw(magic = 0x10u8)] Bipush(i8), + #[brw(magic = 0x34u8)] Caload, + #[brw(magic = 0x55u8)] Castore, + #[brw(magic = 0xc0u8)] Checkcast(u16), + #[brw(magic = 0x90u8)] D2f, + #[brw(magic = 0x8eu8)] D2i, + #[brw(magic = 0x8fu8)] D2l, + #[brw(magic = 0x63u8)] Dadd, + #[brw(magic = 0x31u8)] Daload, + #[brw(magic = 0x52u8)] Dastore, + #[brw(magic = 0x98u8)] Dcmpg, + #[brw(magic = 0x97u8)] Dcmpl, + #[brw(magic = 0x0eu8)] Dconst0, + #[brw(magic = 0x0fu8)] Dconst1, + #[brw(magic = 0x6fu8)] Ddiv, + #[brw(magic = 0x18u8)] Dload(u8), - DloadWide(u16), + #[brw(magic = 0x26u8)] Dload0, + #[brw(magic = 0x27u8)] Dload1, + #[brw(magic = 0x28u8)] Dload2, + #[brw(magic = 0x29u8)] Dload3, + #[brw(magic = 0x6bu8)] Dmul, + #[brw(magic = 0x77u8)] Dneg, + #[brw(magic = 0x73u8)] Drem, + #[brw(magic = 0xafu8)] Dreturn, + 
#[brw(magic = 0x39u8)] Dstore(u8), - DstoreWide(u16), + #[brw(magic = 0x47u8)] Dstore0, + #[brw(magic = 0x48u8)] Dstore1, + #[brw(magic = 0x49u8)] Dstore2, + #[brw(magic = 0x4au8)] Dstore3, + #[brw(magic = 0x67u8)] Dsub, + #[brw(magic = 0x59u8)] Dup, + #[brw(magic = 0x5au8)] Dupx1, + #[brw(magic = 0x5bu8)] Dupx2, + #[brw(magic = 0x5cu8)] Dup2, + #[brw(magic = 0x5du8)] Dup2x1, + #[brw(magic = 0x5eu8)] Dup2x2, + #[brw(magic = 0x8du8)] F2d, + #[brw(magic = 0x8bu8)] F2i, + #[brw(magic = 0x8cu8)] F2l, + #[brw(magic = 0x62u8)] Fadd, + #[brw(magic = 0x30u8)] Faload, + #[brw(magic = 0x51u8)] Fastore, + #[brw(magic = 0x96u8)] Fcmpg, + #[brw(magic = 0x95u8)] Fcmpl, + #[brw(magic = 0xbu8)] Fconst0, + #[brw(magic = 0xcu8)] Fconst1, + #[brw(magic = 0xdu8)] Fconst2, + #[brw(magic = 0x6eu8)] Fdiv, + #[brw(magic = 0x17u8)] Fload(u8), - FloadWide(u16), + #[brw(magic = 0x22u8)] Fload0, + #[brw(magic = 0x23u8)] Fload1, + #[brw(magic = 0x24u8)] Fload2, + #[brw(magic = 0x25u8)] Fload3, + #[brw(magic = 0x6au8)] Fmul, + #[brw(magic = 0x76u8)] Fneg, + #[brw(magic = 0x72u8)] Frem, + #[brw(magic = 0xaeu8)] Freturn, + #[brw(magic = 0x38u8)] Fstore(u8), - FstoreWide(u16), + #[brw(magic = 0x43u8)] Fstore0, + #[brw(magic = 0x44u8)] Fstore1, + #[brw(magic = 0x45u8)] Fstore2, + #[brw(magic = 0x46u8)] Fstore3, + #[brw(magic = 0x66u8)] Fsub, + #[brw(magic = 0xb4u8)] Getfield(u16), + #[brw(magic = 0xb2u8)] Getstatic(u16), + #[brw(magic = 0xa7u8)] Goto(i16), + #[brw(magic = 0xc8u8)] GotoW(i32), + #[brw(magic = 0x91u8)] I2b, + #[brw(magic = 0x92u8)] I2c, + #[brw(magic = 0x87u8)] I2d, + #[brw(magic = 0x86u8)] I2f, + #[brw(magic = 0x85u8)] I2l, + #[brw(magic = 0x93u8)] I2s, + #[brw(magic = 0x60u8)] Iadd, + #[brw(magic = 0x2eu8)] Iaload, + #[brw(magic = 0x7eu8)] Iand, + #[brw(magic = 0x4fu8)] Iastore, + #[brw(magic = 0x2u8)] Iconstm1, + #[brw(magic = 0x3u8)] Iconst0, + #[brw(magic = 0x4u8)] Iconst1, + #[brw(magic = 0x5u8)] Iconst2, + #[brw(magic = 0x6u8)] Iconst3, + #[brw(magic = 0x7u8)] Iconst4, + 
#[brw(magic = 0x8u8)] Iconst5, + #[brw(magic = 0x6cu8)] Idiv, + #[brw(magic = 0xa5u8)] IfAcmpeq(i16), + #[brw(magic = 0xa6u8)] IfAcmpne(i16), + #[brw(magic = 0x9fu8)] IfIcmpeq(i16), + #[brw(magic = 0xa0u8)] IfIcmpne(i16), + #[brw(magic = 0xa1u8)] IfIcmplt(i16), + #[brw(magic = 0xa2u8)] IfIcmpge(i16), + #[brw(magic = 0xa3u8)] IfIcmpgt(i16), + #[brw(magic = 0xa4u8)] IfIcmple(i16), + #[brw(magic = 0x99u8)] Ifeq(i16), + #[brw(magic = 0x9au8)] Ifne(i16), + #[brw(magic = 0x9bu8)] Iflt(i16), + #[brw(magic = 0x9cu8)] Ifge(i16), + #[brw(magic = 0x9du8)] Ifgt(i16), + #[brw(magic = 0x9eu8)] Ifle(i16), + #[brw(magic = 0xc7u8)] Ifnonnull(i16), + #[brw(magic = 0xc6u8)] Ifnull(i16), - Iinc { - index: u8, - value: i8, - }, - IincWide { - index: u16, - value: i16, - }, + #[brw(magic = 0x84u8)] + Iinc { index: u8, value: i8 }, + #[brw(magic = 0x15u8)] Iload(u8), - IloadWide(u16), + #[brw(magic = 0x1au8)] Iload0, + #[brw(magic = 0x1bu8)] Iload1, + #[brw(magic = 0x1cu8)] Iload2, + #[brw(magic = 0x1du8)] Iload3, + #[brw(magic = 0x68u8)] Imul, + #[brw(magic = 0x74u8)] Ineg, + #[brw(magic = 0xc1u8)] Instanceof(u16), - Invokedynamic(u16), - Invokeinterface { - index: u16, - count: u8, - }, + #[brw(magic = 0xbau8)] + Invokedynamic { index: u16, filler: u16 }, + #[brw(magic = 0xb9u8)] + Invokeinterface { index: u16, count: u8, filler: u8 }, + #[brw(magic = 0xb7u8)] Invokespecial(u16), + #[brw(magic = 0xb8u8)] Invokestatic(u16), + #[brw(magic = 0xb6u8)] Invokevirtual(u16), + #[brw(magic = 0x80u8)] Ior, + #[brw(magic = 0x70u8)] Irem, + #[brw(magic = 0xacu8)] Ireturn, + #[brw(magic = 0x78u8)] Ishl, + #[brw(magic = 0x7au8)] Ishr, + #[brw(magic = 0x36u8)] Istore(u8), - IstoreWide(u16), + #[brw(magic = 0x3bu8)] Istore0, + #[brw(magic = 0x3cu8)] Istore1, + #[brw(magic = 0x3du8)] Istore2, + #[brw(magic = 0x3eu8)] Istore3, + #[brw(magic = 0x64u8)] Isub, + #[brw(magic = 0x7cu8)] Iushr, + #[brw(magic = 0x82u8)] Ixor, + #[brw(magic = 0xa8u8)] Jsr(i16), + #[brw(magic = 0xc9u8)] JsrW(i32), + #[brw(magic 
= 0x8au8)] L2d, + #[brw(magic = 0x89u8)] L2f, + #[brw(magic = 0x88u8)] L2i, + #[brw(magic = 0x61u8)] Ladd, + #[brw(magic = 0x2fu8)] Laload, + #[brw(magic = 0x7fu8)] Land, + #[brw(magic = 0x50u8)] Lastore, + #[brw(magic = 0x94u8)] Lcmp, + #[brw(magic = 0x09u8)] Lconst0, + #[brw(magic = 0x0au8)] Lconst1, + #[brw(magic = 0x12u8)] Ldc(u8), + #[brw(magic = 0x13u8)] LdcW(u16), + #[brw(magic = 0x14u8)] Ldc2W(u16), + #[brw(magic = 0x6du8)] Ldiv, + #[brw(magic = 0x16u8)] Lload(u8), - LloadWide(u16), + #[brw(magic = 0x1eu8)] Lload0, + #[brw(magic = 0x1fu8)] Lload1, + #[brw(magic = 0x20u8)] Lload2, + #[brw(magic = 0x21u8)] Lload3, + #[brw(magic = 0x69u8)] Lmul, + #[brw(magic = 0x75u8)] Lneg, + #[brw(magic = 0xabu8)] Lookupswitch { + #[brw(pad_before = ((4 - (address + 1) % 4) % 4))] default: i32, + npairs: u32, + #[br(count = npairs)] pairs: Vec<(i32, i32)>, }, + #[brw(magic = 0x81u8)] Lor, + #[brw(magic = 0x71u8)] Lrem, + #[brw(magic = 0xadu8)] Lreturn, + #[brw(magic = 0x79u8)] Lshl, + #[brw(magic = 0x7bu8)] Lshr, + #[brw(magic = 0x37u8)] Lstore(u8), - LstoreWide(u16), + #[brw(magic = 0x3fu8)] Lstore0, + #[brw(magic = 0x40u8)] Lstore1, + #[brw(magic = 0x41u8)] Lstore2, + #[brw(magic = 0x42u8)] Lstore3, + #[brw(magic = 0x65u8)] Lsub, + #[brw(magic = 0x7du8)] Lushr, + #[brw(magic = 0x83u8)] Lxor, + #[brw(magic = 0xc2u8)] Monitorenter, + #[brw(magic = 0xc3u8)] Monitorexit, - Multianewarray { - index: u16, - dimensions: u8, - }, + #[brw(magic = 0xc5u8)] + Multianewarray { index: u16, dimensions: u8 }, + #[brw(magic = 0xbbu8)] New(u16), + #[brw(magic = 0xbcu8)] Newarray(u8), - Nop, + #[brw(magic = 0x57u8)] Pop, + #[brw(magic = 0x58u8)] Pop2, + #[brw(magic = 0xb5u8)] Putfield(u16), + #[brw(magic = 0xb3u8)] Putstatic(u16), + #[brw(magic = 0xa9u8)] Ret(u8), - RetWide(u16), + #[brw(magic = 0xb1u8)] Return, + #[brw(magic = 0x35u8)] Saload, + #[brw(magic = 0x56u8)] Sastore, + #[brw(magic = 0x11u8)] Sipush(i16), + #[brw(magic = 0x5fu8)] Swap, + #[brw(magic = 0xaau8)] Tableswitch { + 
#[brw(pad_before = ((4 - (address + 1) % 4) % 4))] default: i32, low: i32, high: i32, + #[br(count = high - low + 1)] offsets: Vec<i32>, }, + #[brw(magic = b"\xc4\x19")] + AloadWide(u16), + #[brw(magic = b"\xc4\x3a")] + AstoreWide(u16), + #[brw(magic = b"\xc4\x18")] + DloadWide(u16), + #[brw(magic = b"\xc4\x39")] + DstoreWide(u16), + #[brw(magic = b"\xc4\x17")] + FloadWide(u16), + #[brw(magic = b"\xc4\x38")] + FstoreWide(u16), + #[brw(magic = b"\xc4\x15")] + IloadWide(u16), + #[brw(magic = b"\xc4\x36")] + IstoreWide(u16), + #[brw(magic = b"\xc4\x16")] + LloadWide(u16), + #[brw(magic = b"\xc4\x37")] + LstoreWide(u16), + #[brw(magic = b"\xc4\xa9")] + RetWide(u16), + #[brw(magic = b"\xc4\x84")] + IincWide { index: u16, value: i16 }, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct LocalVariableTableAttribute { pub local_variable_table_length: u16, + #[br(count = local_variable_table_length)] pub items: Vec<LocalVariableTableItem>, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct LocalVariableTableItem { pub start_pc: u16, pub length: u16, @@ -251,12 +468,17 @@ pub struct LocalVariableTableItem { } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct LocalVariableTypeTableAttribute { pub local_variable_type_table_length: u16, + #[br(count = local_variable_type_table_length)] pub local_variable_type_table: Vec<LocalVariableTypeTableItem>, } #[derive(Clone, Debug)] +#[binrw] +#[brw(big)] pub struct LocalVariableTypeTableItem { pub start_pc: u16, pub length: u16, diff --git a/src/compile/ast.rs b/src/compile/ast.rs new file mode 100644 index 0000000..fd7de26 --- /dev/null +++ b/src/compile/ast.rs @@ -0,0 +1,237 @@ +/// Compile AST types — lightweight syntax tree mapping directly to bytecode.
+ +#[derive(Clone, Debug, PartialEq)] +pub enum PrimitiveKind { + Int, + Long, + Float, + Double, + Boolean, + Byte, + Char, + Short, + Void, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum TypeName { + Primitive(PrimitiveKind), + Class(String), + Array(Box), +} + +#[derive(Clone, Debug, PartialEq)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Rem, + Shl, + Shr, + Ushr, + BitAnd, + BitOr, + BitXor, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum CompareOp { + Eq, + Ne, + Lt, + Le, + Gt, + Ge, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum UnaryOp { + Neg, + BitNot, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct SwitchExprCase { + pub values: Vec, + pub expr: CExpr, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct LambdaParam { + pub ty: Option, + pub name: String, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum LambdaBody { + Expr(Box), + Block(Vec), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct SwitchCase { + pub values: Vec, + pub body: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct CatchClause { + pub exception_types: Vec, + pub var_name: String, + pub body: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum CStmt { + LocalDecl { + ty: TypeName, + name: String, + init: Option, + }, + ExprStmt(CExpr), + Return(Option), + If { + condition: CExpr, + then_body: Vec, + else_body: Option>, + }, + While { + condition: CExpr, + body: Vec, + }, + For { + init: Option>, + condition: Option, + update: Option>, + body: Vec, + }, + Block(Vec), + Throw(CExpr), + Break, + Continue, + Switch { + expr: CExpr, + cases: Vec, + default_body: Option>, + }, + TryCatch { + try_body: Vec, + catches: Vec, + finally_body: Option>, + }, + ForEach { + element_type: TypeName, + var_name: String, + iterable: CExpr, + body: Vec, + }, + Synchronized { + lock_expr: CExpr, + body: Vec, + }, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum CExpr { + IntLiteral(i64), + FloatLiteral(f64), + BoolLiteral(bool), + StringLiteral(String), + 
CharLiteral(char), + NullLiteral, + LongLiteral(i64), + DoubleLiteral(f64), + Ident(String), + This, + BinaryOp { + op: BinOp, + left: Box, + right: Box, + }, + UnaryOp { + op: UnaryOp, + operand: Box, + }, + Comparison { + op: CompareOp, + left: Box, + right: Box, + }, + LogicalAnd(Box, Box), + LogicalOr(Box, Box), + LogicalNot(Box), + Assign { + target: Box, + value: Box, + }, + CompoundAssign { + op: BinOp, + target: Box, + value: Box, + }, + PreIncrement(Box), + PreDecrement(Box), + PostIncrement(Box), + PostDecrement(Box), + MethodCall { + object: Option>, + name: String, + args: Vec, + }, + StaticMethodCall { + class_name: String, + name: String, + args: Vec, + }, + FieldAccess { + object: Box, + name: String, + }, + StaticFieldAccess { + class_name: String, + name: String, + }, + NewObject { + class_name: String, + args: Vec, + }, + NewArray { + element_type: TypeName, + size: Box, + }, + NewMultiArray { + element_type: TypeName, + dimensions: Vec, + }, + ArrayAccess { + array: Box, + index: Box, + }, + Cast { + ty: TypeName, + operand: Box, + }, + Instanceof { + operand: Box, + ty: TypeName, + }, + Ternary { + condition: Box, + then_expr: Box, + else_expr: Box, + }, + SwitchExpr { + expr: Box, + cases: Vec, + default_expr: Box, + }, + Lambda { + params: Vec, + body: LambdaBody, + }, + MethodRef { + class_name: String, + method_name: String, + }, +} diff --git a/src/compile/codegen.rs b/src/compile/codegen.rs new file mode 100644 index 0000000..d9c938f --- /dev/null +++ b/src/compile/codegen.rs @@ -0,0 +1,4388 @@ +use crate::ClassFile; +use crate::attribute_info::{ + AttributeInfo, AttributeInfoVariant, BootstrapMethod, BootstrapMethodsAttribute, CodeAttribute, +}; +use crate::code_attribute::Instruction; +use crate::decompile::descriptor::{JvmType, parse_method_descriptor}; +use crate::decompile::util::instruction_byte_size; +use crate::method_info::{MethodAccessFlags, MethodInfo}; + +use super::CompileError; +use super::ast::*; +use 
super::stackmap::{FrameTracker, VType}; + +/// Tracks local variable allocation. +struct LocalAllocator { + /// (name, type, slot, vtype) + locals: Vec<(String, TypeName, u16, VType)>, + next_slot: u16, +} + +impl LocalAllocator { + fn new( + is_static: bool, + method_descriptor: &str, + class_file: &mut ClassFile, + param_names: &[Option], + ) -> Result { + let mut next_slot: u16 = 0; + let mut locals = Vec::new(); + + if !is_static { + let vtype = VType::Object(class_file.this_class); + locals.push(("this".into(), TypeName::Class("this".into()), 0, vtype)); + next_slot = 1; + } + + // Parse method descriptor to pre-allocate parameter slots + let (params, _) = parse_method_descriptor(method_descriptor).ok_or_else(|| { + CompileError::CodegenError { + message: format!("invalid method descriptor: {}", method_descriptor), + } + })?; + + for (i, param) in params.iter().enumerate() { + let ty = jvm_type_to_type_name(param); + let vtype = jvm_type_to_vtype_resolved(param, class_file); + let slot = next_slot; + // Always register positional name arg{i} + let positional = format!("arg{}", i); + locals.push((positional.clone(), ty.clone(), slot, vtype.clone())); + // If a debug name is available and differs from the positional name, register as alias + if let Some(Some(debug_name)) = param_names.get(i) + && debug_name != &positional + { + locals.push((debug_name.clone(), ty, slot, vtype)); + } + next_slot += if param.is_wide() { 2 } else { 1 }; + } + + Ok(LocalAllocator { locals, next_slot }) + } + + fn allocate(&mut self, name: &str, ty: &TypeName) -> u16 { + let vtype = type_name_to_vtype(ty); + self.allocate_with_vtype(name, ty, vtype) + } + + fn allocate_with_vtype(&mut self, name: &str, ty: &TypeName, vtype: VType) -> u16 { + let slot = self.next_slot; + let width = type_slot_width(ty); + self.locals + .push((name.to_string(), ty.clone(), slot, vtype)); + self.next_slot += width; + slot + } + + fn find(&self, name: &str) -> Option<(u16, &TypeName)> { + // Search from 
the end to support shadowing + for (n, ty, slot, _) in self.locals.iter().rev() { + if n == name { + return Some((*slot, ty)); + } + } + None + } + + fn max_locals(&self) -> u16 { + self.next_slot + } + + /// Save the current locals state for later restoration (scope support). + /// `next_slot` is NOT saved; it only ever increases to ensure max_locals is correct. + fn save(&self) -> Vec<(String, TypeName, u16, VType)> { + self.locals.clone() + } + + /// Restore locals to a previous state (scope exit). Body-scoped locals are removed, + /// but `next_slot` stays at its high-water mark so max_locals remains correct. + fn restore(&mut self, saved: Vec<(String, TypeName, u16, VType)>) { + self.locals = saved; + } + + /// Get current locals as VType array for StackMapTable generation. + /// Returns the compressed format required by StackMapTable: Long/Double + /// implicitly cover 2 slots, so continuation Top entries are omitted. + fn current_locals_vtypes(&self) -> Vec<VType> { + let mut slot_vtypes = vec![VType::Top; self.next_slot as usize]; + for (_, _, slot, vtype) in &self.locals { + slot_vtypes[*slot as usize] = vtype.clone(); + } + // Convert slot-indexed array to StackMapTable format: + // Skip the implicit continuation slot after Long/Double.
+ let mut vtypes = Vec::new(); + let mut i = 0; + while i < slot_vtypes.len() { + vtypes.push(slot_vtypes[i].clone()); + if slot_vtypes[i] == VType::Long || slot_vtypes[i] == VType::Double { + i += 2; // skip continuation slot + } else { + i += 1; + } + } + // Trim trailing Top values + while vtypes.last() == Some(&VType::Top) { + vtypes.pop(); + } + vtypes + } +} + +struct BreakableContext { + break_label: usize, + is_loop: bool, + continue_label: Option<usize>, +} + +enum SwitchPatchKind { + Table { + low: i32, + high: i32, + /// Labels for offsets[0..=(high-low)], then default_label + case_labels: Vec<usize>, + default_label: usize, + }, + Lookup { + /// (match_value, label) + pairs: Vec<(i32, usize)>, + default_label: usize, + }, +} + +struct SwitchPatch { + instr_idx: usize, + kind: SwitchPatchKind, +} + +struct PendingExceptionEntry { + start_label: usize, + end_label: usize, + handler_label: usize, + catch_type: u16, +} + +pub struct CodeGenerator<'a> { + class_file: &'a mut ClassFile, + instructions: Vec<Instruction>, + locals: LocalAllocator, + labels: Vec<Option<usize>>, // label_id → instruction index (None = unresolved) + patches: Vec<(usize, usize)>, // (instruction_index, target_label_id) + switch_patches: Vec<SwitchPatch>, + breakable_stack: Vec<BreakableContext>, + pending_exceptions: Vec<PendingExceptionEntry>, + is_static: bool, + return_type: JvmType, + frame_tracker: Option<FrameTracker>, + /// Labels that are exception handler entry points: (label_id, exception_vtype, locals_at_try_start). + exception_handler_labels: Vec<(usize, VType, Vec<VType>)>, + /// Labels whose frame locals should be overridden (not taken from current allocator). + label_locals_override: Vec<(usize, Vec<VType>)>, + /// Labels whose frame stack should be overridden (not empty).
+ label_stack_override: Vec<(usize, Vec<VType>)>, +} + +impl<'a> CodeGenerator<'a> { + pub fn new( + class_file: &'a mut ClassFile, + is_static: bool, + method_descriptor: &str, + param_names: &[Option<String>], + ) -> Result<Self, CompileError> { + Self::new_with_options(class_file, is_static, method_descriptor, false, param_names) + } + + pub fn new_with_options( + class_file: &'a mut ClassFile, + is_static: bool, + method_descriptor: &str, + generate_stack_map_table: bool, + param_names: &[Option<String>], + ) -> Result<Self, CompileError> { + let locals = LocalAllocator::new(is_static, method_descriptor, class_file, param_names)?; + let (params, ret) = parse_method_descriptor(method_descriptor).ok_or_else(|| { + CompileError::CodegenError { + message: format!("invalid method descriptor: {}", method_descriptor), + } + })?; + + let frame_tracker = if generate_stack_map_table { + // Build initial locals VTypes from method descriptor + let mut initial = Vec::new(); + if !is_static { + // 'this' reference: resolve the class + initial.push(VType::Object(class_file.this_class)); + } + for param in &params { + initial.push(jvm_type_to_vtype_resolved(param, class_file)); + // Note: Long/Double implicitly cover 2 local variable slots in + // StackMapTable encoding. Do NOT add an explicit Top continuation; + // the JVM spec says the next entry maps to slot N+2 automatically.
+ } + Some(FrameTracker::new(initial)) + } else { + None + }; + + Ok(CodeGenerator { + class_file, + instructions: Vec::new(), + locals, + labels: Vec::new(), + patches: Vec::new(), + switch_patches: Vec::new(), + breakable_stack: Vec::new(), + pending_exceptions: Vec::new(), + is_static, + return_type: ret, + frame_tracker, + exception_handler_labels: Vec::new(), + label_locals_override: Vec::new(), + label_stack_override: Vec::new(), + }) + } + + fn new_label(&mut self) -> usize { + let id = self.labels.len(); + self.labels.push(None); + id + } + + fn bind_label(&mut self, label_id: usize) { + let instr_idx = self.instructions.len(); + self.labels[label_id] = Some(instr_idx); + + // Record frame snapshot for StackMapTable generation + if self.frame_tracker.is_some() { + // Check if this label is an exception handler entry point + let exception_info = self + .exception_handler_labels + .iter() + .find(|(lid, _, _)| *lid == label_id) + .map(|(_, vtype, saved_locals)| (vtype.clone(), saved_locals.clone())); + + let (locals, stack) = if let Some((vtype, saved_locals)) = exception_info { + // Exception handlers use the locals from try-start, not current allocator state + (saved_locals, vec![vtype]) + } else { + // Check for explicit locals override (e.g., merge points after try-catch) + let overridden_locals = self + .label_locals_override + .iter() + .find(|(lid, _)| *lid == label_id) + .map(|(_, locals)| locals.clone()); + let locals = + overridden_locals.unwrap_or_else(|| self.locals.current_locals_vtypes()); + // Check for explicit stack override (e.g., expression-level merge points) + let stack = self + .label_stack_override + .iter() + .find(|(lid, _)| *lid == label_id) + .map(|(_, stack)| stack.clone()) + .unwrap_or_default(); + (locals, stack) + }; + + // We need to compute the bytecode offset for this instruction index. + // Since instructions haven't been patched yet, compute from current instructions. 
+ let offset = compute_byte_offset_at(&self.instructions, instr_idx); + + if let Some(ref mut tracker) = self.frame_tracker { + tracker.record_frame(offset, locals, stack); + } + } + } + + fn emit(&mut self, instr: Instruction) -> usize { + let idx = self.instructions.len(); + self.instructions.push(instr); + idx + } + + fn emit_branch(&mut self, instr_fn: fn(i16) -> Instruction, target_label: usize) { + let idx = self.emit(instr_fn(0)); // placeholder + self.patches.push((idx, target_label)); + } + + fn emit_goto(&mut self, target_label: usize) { + self.emit_branch(Instruction::Goto, target_label); + } + + /// Returns true if the last emitted instruction is an unconditional control transfer + /// (goto, return, athrow). Used to avoid emitting dead code that would require + /// a StackMapTable frame the JVM verifier would complain about. + fn last_is_unconditional_transfer(&self) -> bool { + match self.instructions.last() { + Some(Instruction::Goto(_)) + | Some(Instruction::GotoW(_)) + | Some(Instruction::Return) + | Some(Instruction::Ireturn) + | Some(Instruction::Lreturn) + | Some(Instruction::Freturn) + | Some(Instruction::Dreturn) + | Some(Instruction::Areturn) + | Some(Instruction::Athrow) => true, + _ => false, + } + } + + /// Resolve all branch patches using byte addresses. 
+ fn resolve_patches(&mut self) -> Result<(), CompileError> { + // Compute byte addresses for each instruction + let addresses = compute_byte_addresses(&self.instructions); + + // Compute end address (address after last instruction) + let end_addr = if self.instructions.is_empty() { + 0i32 + } else { + let last = *addresses.last().unwrap(); + (last + instruction_byte_size(self.instructions.last().unwrap(), last as u32)) as i32 + }; + + for &(instr_idx, label_id) in &self.patches { + let source_addr = addresses[instr_idx] as i32; + let target_addr = resolve_label_addr(label_id, &self.labels, &addresses, end_addr)?; + let offset = target_addr - source_addr; + let offset16 = offset as i16; + + self.instructions[instr_idx] = + patch_branch_offset(&self.instructions[instr_idx], offset16)?; + } + + // Resolve switch patches + for patch in &self.switch_patches { + let source_addr = addresses[patch.instr_idx] as i32; + match &patch.kind { + SwitchPatchKind::Table { + low, + high, + case_labels, + default_label, + } => { + let default_offset = + resolve_label_addr(*default_label, &self.labels, &addresses, end_addr)? + - source_addr; + let mut offsets = Vec::new(); + for label_id in case_labels { + let addr = + resolve_label_addr(*label_id, &self.labels, &addresses, end_addr)?; + offsets.push(addr - source_addr); + } + self.instructions[patch.instr_idx] = Instruction::Tableswitch { + default: default_offset, + low: *low, + high: *high, + offsets, + }; + } + SwitchPatchKind::Lookup { + pairs, + default_label, + } => { + let default_offset = + resolve_label_addr(*default_label, &self.labels, &addresses, end_addr)? 
+ - source_addr; + let mut resolved_pairs = Vec::new(); + for (value, label_id) in pairs { + let addr = + resolve_label_addr(*label_id, &self.labels, &addresses, end_addr)?; + resolved_pairs.push((*value, addr - source_addr)); + } + self.instructions[patch.instr_idx] = Instruction::Lookupswitch { + default: default_offset, + npairs: resolved_pairs.len() as u32, + pairs: resolved_pairs, + }; + } + } + } + + Ok(()) + } + + pub fn generate_body(&mut self, stmts: &[CStmt]) -> Result<(), CompileError> { + for stmt in stmts { + self.gen_stmt(stmt)?; + } + // If the method returns void, ensure there's a trailing return. + // Any label pointing to `instructions.len()` (end of code) needs a + // valid instruction there, so always emit Return for void methods + // when the last emitted instruction isn't already a return/throw, + // OR when there are labels that point to the end of instructions. + if self.return_type == JvmType::Void { + let has_label_at_end = self.labels.contains(&Some(self.instructions.len())); + let needs_return = self.instructions.is_empty() + || has_label_at_end + || !matches!( + self.instructions.last(), + Some(Instruction::Return) + | Some(Instruction::Ireturn) + | Some(Instruction::Lreturn) + | Some(Instruction::Freturn) + | Some(Instruction::Dreturn) + | Some(Instruction::Areturn) + | Some(Instruction::Athrow) + ); + if needs_return { + self.emit(Instruction::Return); + } + } + Ok(()) + } + + pub fn finish(mut self) -> Result { + self.resolve_patches()?; + let max_stack = super::stack_calc::compute_max_stack(&self.instructions); + let max_locals = self.locals.max_locals(); + let exception_table = self.build_exception_table()?; + let stack_map_table = self.frame_tracker.take().and_then(|t| t.build()); + Ok(super::GeneratedCode { + instructions: self.instructions, + max_stack, + max_locals, + exception_table, + stack_map_table, + }) + } + + fn build_exception_table( + &self, + ) -> Result, CompileError> { + use crate::attribute_info::ExceptionEntry; + 
let addresses = compute_byte_addresses(&self.instructions); + let end_addr = { + if self.instructions.is_empty() { + 0u16 + } else { + let last = addresses.last().copied().unwrap_or(0); + let last_instr = &self.instructions[self.instructions.len() - 1]; + (last + instruction_byte_size(last_instr, last as u32)) as u16 + } + }; + + let mut entries = Vec::new(); + for pending in &self.pending_exceptions { + let start_instr = + self.labels[pending.start_label].ok_or_else(|| CompileError::CodegenError { + message: "unresolved exception start label".into(), + })?; + let end_instr = + self.labels[pending.end_label].ok_or_else(|| CompileError::CodegenError { + message: "unresolved exception end label".into(), + })?; + let handler_instr = + self.labels[pending.handler_label].ok_or_else(|| CompileError::CodegenError { + message: "unresolved exception handler label".into(), + })?; + + let start_pc = if start_instr < addresses.len() { + addresses[start_instr] as u16 + } else { + end_addr + }; + let end_pc = if end_instr < addresses.len() { + addresses[end_instr] as u16 + } else { + end_addr + }; + let handler_pc = if handler_instr < addresses.len() { + addresses[handler_instr] as u16 + } else { + end_addr + }; + + entries.push(ExceptionEntry { + start_pc, + end_pc, + handler_pc, + catch_type: pending.catch_type, + }); + } + Ok(entries) + } + + // --- Statement codegen --- + + fn gen_stmt(&mut self, stmt: &CStmt) -> Result<(), CompileError> { + match stmt { + CStmt::LocalDecl { ty, name, init } => { + let resolved_ty = if is_var_sentinel(ty) { + match init { + Some(expr) => self.infer_expr_type(expr), + None => { + return Err(CompileError::CodegenError { + message: "'var' requires an initializer".into(), + }); + } + } + } else { + ty.clone() + }; + let vtype = type_name_to_vtype_resolved(&resolved_ty, self.class_file); + if let Some(expr) = init { + // Generate initializer BEFORE allocating the slot so that + // branch targets inside the initializer (ternaries, switch + // 
expressions, comparisons) don't include the unassigned + // local in their StackMapTable frames. + self.gen_expr(expr)?; + let slot = self.locals.allocate_with_vtype(name, &resolved_ty, vtype); + self.emit_store(&resolved_ty, slot); + } else { + self.locals.allocate_with_vtype(name, &resolved_ty, vtype); + } + Ok(()) + } + CStmt::ExprStmt(expr) => { + self.gen_expr(expr)?; + // Pop the value if the expression leaves one on the stack + if self.expr_leaves_value(expr) { + let ty = self.infer_expr_type(expr); + match &ty { + TypeName::Primitive(PrimitiveKind::Long) + | TypeName::Primitive(PrimitiveKind::Double) => { + self.emit(Instruction::Pop2); + } + _ => { + self.emit(Instruction::Pop); + } + } + } + Ok(()) + } + CStmt::Return(None) => { + self.emit(Instruction::Return); + Ok(()) + } + CStmt::Return(Some(expr)) => { + self.gen_expr(expr)?; + let ret_instr = match &self.return_type { + JvmType::Int + | JvmType::Boolean + | JvmType::Byte + | JvmType::Char + | JvmType::Short => Instruction::Ireturn, + JvmType::Long => Instruction::Lreturn, + JvmType::Float => Instruction::Freturn, + JvmType::Double => Instruction::Dreturn, + JvmType::Reference(_) | JvmType::Array(_) | JvmType::Null => { + Instruction::Areturn + } + JvmType::Void => Instruction::Return, + JvmType::Unknown => Instruction::Areturn, + }; + self.emit(ret_instr); + Ok(()) + } + CStmt::If { + condition, + then_body, + else_body, + } => { + let false_label = self.new_label(); + let pre_branch_locals = self.locals.current_locals_vtypes(); + self.label_locals_override + .push((false_label, pre_branch_locals.clone())); + self.gen_condition(condition, false_label, false)?; + let saved = self.locals.save(); + for s in then_body { + self.gen_stmt(s)?; + } + if let Some(else_stmts) = else_body { + let end_label = self.new_label(); + self.label_locals_override + .push((end_label, pre_branch_locals)); + self.emit_goto(end_label); + self.locals.restore(saved.clone()); + self.bind_label(false_label); + for s in 
else_stmts { + self.gen_stmt(s)?; + } + self.locals.restore(saved); + self.bind_label(end_label); + } else { + self.locals.restore(saved); + self.bind_label(false_label); + } + Ok(()) + } + CStmt::While { condition, body } => { + let top_label = self.new_label(); + let end_label = self.new_label(); + let pre_body_locals = self.locals.current_locals_vtypes(); + self.label_locals_override + .push((end_label, pre_body_locals)); + self.breakable_stack.push(BreakableContext { + break_label: end_label, + is_loop: true, + continue_label: Some(top_label), + }); + self.bind_label(top_label); + self.gen_condition(condition, end_label, false)?; + let saved = self.locals.save(); + for s in body { + self.gen_stmt(s)?; + } + self.locals.restore(saved); + self.emit_goto(top_label); + self.bind_label(end_label); + self.breakable_stack.pop(); + Ok(()) + } + CStmt::For { + init, + condition, + update, + body, + } => { + if let Some(init_stmt) = init { + self.gen_stmt(init_stmt)?; + } + let top_label = self.new_label(); + let update_label = self.new_label(); + let end_label = self.new_label(); + let pre_body_locals = self.locals.current_locals_vtypes(); + self.label_locals_override + .push((end_label, pre_body_locals.clone())); + self.label_locals_override + .push((update_label, pre_body_locals)); + self.breakable_stack.push(BreakableContext { + break_label: end_label, + is_loop: true, + continue_label: Some(update_label), + }); + self.bind_label(top_label); + if let Some(cond) = condition { + self.gen_condition(cond, end_label, false)?; + } + let saved = self.locals.save(); + for s in body { + self.gen_stmt(s)?; + } + self.locals.restore(saved); + self.bind_label(update_label); + if let Some(upd) = update { + self.gen_stmt(upd)?; + } + self.emit_goto(top_label); + self.bind_label(end_label); + self.breakable_stack.pop(); + Ok(()) + } + CStmt::Block(stmts) => { + let saved = self.locals.save(); + for s in stmts { + self.gen_stmt(s)?; + } + self.locals.restore(saved); + Ok(()) + } + 
CStmt::Throw(expr) => { + self.gen_expr(expr)?; + self.emit(Instruction::Athrow); + Ok(()) + } + CStmt::Break => { + let label = self + .breakable_stack + .last() + .ok_or_else(|| CompileError::CodegenError { + message: "break outside loop or switch".into(), + })? + .break_label; + self.emit_goto(label); + Ok(()) + } + CStmt::Continue => { + // Search backwards for the first loop context + let label = self + .breakable_stack + .iter() + .rev() + .find(|ctx| ctx.is_loop) + .and_then(|ctx| ctx.continue_label) + .ok_or_else(|| CompileError::CodegenError { + message: "continue outside loop".into(), + })?; + self.emit_goto(label); + Ok(()) + } + CStmt::Switch { + expr, + cases, + default_body, + } => { + self.gen_switch(expr, cases, default_body.as_deref())?; + Ok(()) + } + CStmt::TryCatch { + try_body, + catches, + finally_body, + } => { + self.gen_try_catch(try_body, catches, finally_body.as_deref())?; + Ok(()) + } + CStmt::ForEach { + element_type, + var_name, + iterable, + body, + } => { + self.gen_foreach(element_type, var_name, iterable, body)?; + Ok(()) + } + CStmt::Synchronized { lock_expr, body } => { + self.gen_synchronized(lock_expr, body)?; + Ok(()) + } + } + } + + // --- Expression codegen --- + + fn gen_expr(&mut self, expr: &CExpr) -> Result<(), CompileError> { + match expr { + CExpr::IntLiteral(v) => { + self.emit_int_const(*v); + Ok(()) + } + CExpr::LongLiteral(v) => { + self.emit_long_const(*v); + Ok(()) + } + CExpr::FloatLiteral(v) => { + self.emit_float_const(*v as f32); + Ok(()) + } + CExpr::DoubleLiteral(v) => { + self.emit_double_const(*v); + Ok(()) + } + CExpr::BoolLiteral(b) => { + self.emit(if *b { + Instruction::Iconst1 + } else { + Instruction::Iconst0 + }); + Ok(()) + } + CExpr::StringLiteral(s) => { + let cp_idx = self.class_file.get_or_add_string(s); + self.emit_ldc(cp_idx); + Ok(()) + } + CExpr::CharLiteral(c) => { + self.emit_int_const(*c as i64); + Ok(()) + } + CExpr::NullLiteral => { + self.emit(Instruction::Aconstnull); + Ok(()) + } + 
CExpr::Ident(name) => { + let (slot, ty) = + self.locals + .find(name) + .ok_or_else(|| CompileError::CodegenError { + message: format!("undefined variable: {}", name), + })?; + let ty = ty.clone(); + self.emit_load(&ty, slot); + Ok(()) + } + CExpr::This => { + if self.is_static { + return Err(CompileError::CodegenError { + message: "'this' not available in static method".into(), + }); + } + self.emit(Instruction::Aload0); + Ok(()) + } + CExpr::BinaryOp { op, left, right } => { + let left_ty = self.infer_expr_type(left); + let right_ty = self.infer_expr_type(right); + + // String concatenation: either operand is a String + if *op == BinOp::Add && (is_string_type(&left_ty) || is_string_type(&right_ty)) { + return self.gen_string_concat(expr); + } + + let promoted = promote_numeric_type(&left_ty, &right_ty); + self.gen_expr(left)?; + self.emit_widen_if_needed(&left_ty, &promoted); + self.gen_expr(right)?; + // Shift ops: right operand is always int, no widening + if !matches!(op, BinOp::Shl | BinOp::Shr | BinOp::Ushr) { + self.emit_widen_if_needed(&right_ty, &promoted); + } + self.emit_typed_binary_op(op, &promoted)?; + Ok(()) + } + CExpr::UnaryOp { op, operand } => { + let ty = self.infer_expr_type(operand); + self.gen_expr(operand)?; + match op { + UnaryOp::Neg => { + if is_long_type(&ty) { + self.emit(Instruction::Lneg); + } else if is_float_type(&ty) { + self.emit(Instruction::Fneg); + } else if is_double_type(&ty) { + self.emit(Instruction::Dneg); + } else { + self.emit(Instruction::Ineg); + } + } + UnaryOp::BitNot => { + // ~x == x ^ -1 + if is_long_type(&ty) { + let cp_idx = self.class_file.get_or_add_long(-1); + self.emit(Instruction::Ldc2W(cp_idx)); + self.emit(Instruction::Lxor); + } else { + self.emit(Instruction::Iconstm1); + self.emit(Instruction::Ixor); + } + } + } + Ok(()) + } + CExpr::Comparison { op, left, right } => { + // Evaluate comparison to 0/1 using branches + let left_ty = self.infer_expr_type(left); + let right_ty = 
self.infer_expr_type(right); + let promoted = promote_numeric_type(&left_ty, &right_ty); + self.gen_expr(left)?; + self.emit_widen_if_needed(&left_ty, &promoted); + self.gen_expr(right)?; + self.emit_widen_if_needed(&right_ty, &promoted); + let true_label = self.new_label(); + let end_label = self.new_label(); + // end_label has Integer on stack (result of iconst_0 or iconst_1) + self.label_stack_override + .push((end_label, vec![VType::Integer])); + if is_reference_type(&promoted) { + // Reference equality: use if_acmpeq / if_acmpne + let branch = match op { + CompareOp::Eq => Instruction::IfAcmpeq as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::IfAcmpne, + _ => { + return Err(CompileError::CodegenError { + message: "cannot use relational operator on reference types".into(), + }); + } + }; + self.emit_branch(branch, true_label); + } else if is_int_type(&promoted) { + let branch = match op { + CompareOp::Eq => Instruction::IfIcmpeq as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::IfIcmpne, + CompareOp::Lt => Instruction::IfIcmplt, + CompareOp::Le => Instruction::IfIcmple, + CompareOp::Gt => Instruction::IfIcmpgt, + CompareOp::Ge => Instruction::IfIcmpge, + }; + self.emit_branch(branch, true_label); + } else { + // long/float/double: emit compare instruction, then branch on int result + self.emit_typed_compare(&promoted, op); + let branch = match op { + CompareOp::Eq => Instruction::Ifeq as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifne, + CompareOp::Lt => Instruction::Iflt, + CompareOp::Le => Instruction::Ifle, + CompareOp::Gt => Instruction::Ifgt, + CompareOp::Ge => Instruction::Ifge, + }; + self.emit_branch(branch, true_label); + } + self.emit(Instruction::Iconst0); + self.emit_goto(end_label); + self.bind_label(true_label); + self.emit(Instruction::Iconst1); + self.bind_label(end_label); + Ok(()) + } + CExpr::LogicalAnd(left, right) => { + let false_label = self.new_label(); + let end_label = self.new_label(); + // end_label 
has Integer on stack (result of iconst_0 or iconst_1) + self.label_stack_override + .push((end_label, vec![VType::Integer])); + self.gen_condition(left, false_label, false)?; + self.gen_condition(right, false_label, false)?; + self.emit(Instruction::Iconst1); + self.emit_goto(end_label); + self.bind_label(false_label); + self.emit(Instruction::Iconst0); + self.bind_label(end_label); + Ok(()) + } + CExpr::LogicalOr(left, right) => { + let true_label = self.new_label(); + let false_label = self.new_label(); + let end_label = self.new_label(); + // end_label has Integer on stack (result of iconst_0 or iconst_1) + self.label_stack_override + .push((end_label, vec![VType::Integer])); + self.gen_condition(left, true_label, true)?; + self.gen_condition(right, false_label, false)?; + self.bind_label(true_label); + self.emit(Instruction::Iconst1); + self.emit_goto(end_label); + self.bind_label(false_label); + self.emit(Instruction::Iconst0); + self.bind_label(end_label); + Ok(()) + } + CExpr::LogicalNot(operand) => { + let true_label = self.new_label(); + let end_label = self.new_label(); + // end_label has Integer on stack (result of iconst_0 or iconst_1) + self.label_stack_override + .push((end_label, vec![VType::Integer])); + self.gen_condition(operand, true_label, true)?; + // Condition was false, so !cond is true + self.emit(Instruction::Iconst1); + self.emit_goto(end_label); + self.bind_label(true_label); + // Condition was true, so !cond is false + self.emit(Instruction::Iconst0); + self.bind_label(end_label); + Ok(()) + } + CExpr::Assign { target, value } => { + // Special-case array stores: emit arrayref, index, value, then xastore + if let CExpr::ArrayAccess { array, index } = target.as_ref() { + let array_ty = self.infer_expr_type(array); + let elem_ty = match &array_ty { + TypeName::Array(inner) => inner.as_ref().clone(), + _ => TypeName::Primitive(PrimitiveKind::Int), + }; + self.gen_expr(array)?; + self.gen_expr(index)?; + self.gen_expr(value)?; + // Duplicate 
the value under [arrayref, index] so that the expression + // result remains on the stack after the array store. This matches Java + // semantics: `counter[0] = x` evaluates to the stored value. + // dup_x2: category-1 value under two category-1 values → value stays on top + // dup2_x2: category-2 value (long/double) under two category-1 values + if type_slot_width(&elem_ty) == 2 { + self.emit(Instruction::Dup2x2); + } else { + self.emit(Instruction::Dupx2); + } + self.emit_array_store(&array_ty); + } else { + self.gen_expr(value)?; + if type_slot_width(&self.infer_expr_type(value)) == 2 { + self.emit(Instruction::Dup2); + } else { + self.emit(Instruction::Dup); + } + self.gen_store_target(target)?; + } + Ok(()) + } + CExpr::CompoundAssign { op, target, value } => { + // Load current, compute, dup, store + let target_ty = self.infer_expr_type(target); + let value_ty = self.infer_expr_type(value); + let promoted = promote_numeric_type(&target_ty, &value_ty); + self.gen_expr(target)?; + self.emit_widen_if_needed(&target_ty, &promoted); + self.gen_expr(value)?; + if !matches!(op, BinOp::Shl | BinOp::Shr | BinOp::Ushr) { + self.emit_widen_if_needed(&value_ty, &promoted); + } + self.emit_typed_binary_op(op, &promoted)?; + // Narrow back to target type if needed (e.g. 
int += double would need d2i) + if numeric_rank(&promoted) > numeric_rank(&target_ty) { + self.emit_narrow(&promoted, &target_ty); + } + if type_slot_width(&target_ty) == 2 { + self.emit(Instruction::Dup2); + } else { + self.emit(Instruction::Dup); + } + self.gen_store_target(target)?; + Ok(()) + } + CExpr::PreIncrement(operand) => { + if let CExpr::Ident(name) = operand.as_ref() { + let (slot, ty) = + self.locals + .find(name) + .ok_or_else(|| CompileError::CodegenError { + message: format!("undefined variable: {}", name), + })?; + let ty = ty.clone(); + let slot = slot; + if is_int_type(&ty) && slot <= 255 { + self.emit(Instruction::Iinc { + index: slot as u8, + value: 1, + }); + self.emit_load(&ty, slot); + } else { + self.emit_load(&ty, slot); + self.emit_typed_const_one(&ty); + self.emit_typed_binary_op(&BinOp::Add, &ty)?; + if type_slot_width(&ty) == 2 { + self.emit(Instruction::Dup2); + } else { + self.emit(Instruction::Dup); + } + self.emit_store(&ty, slot); + } + Ok(()) + } else { + Err(CompileError::CodegenError { + message: "pre-increment requires simple variable".into(), + }) + } + } + CExpr::PreDecrement(operand) => { + if let CExpr::Ident(name) = operand.as_ref() { + let (slot, ty) = + self.locals + .find(name) + .ok_or_else(|| CompileError::CodegenError { + message: format!("undefined variable: {}", name), + })?; + let ty = ty.clone(); + let slot = slot; + if is_int_type(&ty) && slot <= 255 { + self.emit(Instruction::Iinc { + index: slot as u8, + value: -1, + }); + self.emit_load(&ty, slot); + } else { + self.emit_load(&ty, slot); + self.emit_typed_const_one(&ty); + self.emit_typed_binary_op(&BinOp::Sub, &ty)?; + if type_slot_width(&ty) == 2 { + self.emit(Instruction::Dup2); + } else { + self.emit(Instruction::Dup); + } + self.emit_store(&ty, slot); + } + Ok(()) + } else { + Err(CompileError::CodegenError { + message: "pre-decrement requires simple variable".into(), + }) + } + } + CExpr::PostIncrement(operand) => { + if let CExpr::Ident(name) = 
operand.as_ref() { + let (slot, ty) = + self.locals + .find(name) + .ok_or_else(|| CompileError::CodegenError { + message: format!("undefined variable: {}", name), + })?; + let ty = ty.clone(); + let slot = slot; + self.emit_load(&ty, slot); + if is_int_type(&ty) && slot <= 255 { + self.emit(Instruction::Iinc { + index: slot as u8, + value: 1, + }); + } else { + if type_slot_width(&ty) == 2 { + self.emit(Instruction::Dup2); + } else { + self.emit(Instruction::Dup); + } + self.emit_typed_const_one(&ty); + self.emit_typed_binary_op(&BinOp::Add, &ty)?; + self.emit_store(&ty, slot); + } + Ok(()) + } else { + Err(CompileError::CodegenError { + message: "post-increment requires simple variable".into(), + }) + } + } + CExpr::PostDecrement(operand) => { + if let CExpr::Ident(name) = operand.as_ref() { + let (slot, ty) = + self.locals + .find(name) + .ok_or_else(|| CompileError::CodegenError { + message: format!("undefined variable: {}", name), + })?; + let ty = ty.clone(); + let slot = slot; + self.emit_load(&ty, slot); + if is_int_type(&ty) && slot <= 255 { + self.emit(Instruction::Iinc { + index: slot as u8, + value: -1, + }); + } else { + if type_slot_width(&ty) == 2 { + self.emit(Instruction::Dup2); + } else { + self.emit(Instruction::Dup); + } + self.emit_typed_const_one(&ty); + self.emit_typed_binary_op(&BinOp::Sub, &ty)?; + self.emit_store(&ty, slot); + } + Ok(()) + } else { + Err(CompileError::CodegenError { + message: "post-decrement requires simple variable".into(), + }) + } + } + CExpr::MethodCall { object, name, args } => { + self.gen_method_call(object.as_deref(), name, args) + } + CExpr::StaticMethodCall { + class_name, + name, + args, + } => self.gen_static_method_call(class_name, name, args), + CExpr::FieldAccess { object, name } => self.gen_field_access(object, name), + CExpr::StaticFieldAccess { class_name, name } => { + self.gen_static_field_access(class_name, name) + } + CExpr::NewObject { class_name, args } => { + let internal = 
resolve_class_name(class_name); + let class_idx = self.class_file.get_or_add_class(&internal); + self.emit(Instruction::New(class_idx)); + self.emit(Instruction::Dup); + for arg in args { + self.gen_expr(arg)?; + } + // Default constructor descriptor: try to infer from arg count + let descriptor = self.infer_constructor_descriptor(args)?; + let method_idx = + self.class_file + .get_or_add_method_ref(&internal, "<init>", &descriptor); + self.emit(Instruction::Invokespecial(method_idx)); + Ok(()) + } + CExpr::NewArray { element_type, size } => { + self.gen_expr(size)?; + match element_type { + TypeName::Primitive(kind) => { + let atype = match kind { + PrimitiveKind::Boolean => 4, + PrimitiveKind::Char => 5, + PrimitiveKind::Float => 6, + PrimitiveKind::Double => 7, + PrimitiveKind::Byte => 8, + PrimitiveKind::Short => 9, + PrimitiveKind::Int => 10, + PrimitiveKind::Long => 11, + PrimitiveKind::Void => { + return Err(CompileError::CodegenError { + message: "cannot create array of void".into(), + }); + } + }; + self.emit(Instruction::Newarray(atype)); + } + TypeName::Class(name) => { + let internal = resolve_class_name(name); + let class_idx = self.class_file.get_or_add_class(&internal); + self.emit(Instruction::Anewarray(class_idx)); + } + TypeName::Array(_) => { + // Multi-dimensional: create array of arrays + let descriptor = type_name_to_descriptor(element_type); + let class_idx = self.class_file.get_or_add_class(&descriptor); + self.emit(Instruction::Anewarray(class_idx)); + } + } + Ok(()) + } + CExpr::NewMultiArray { + element_type, + dimensions, + } => { + for dim in dimensions { + self.gen_expr(dim)?; + } + let mut desc = String::new(); + for _ in 0..dimensions.len() { + desc.push('['); + } + desc.push_str(&type_name_to_descriptor(element_type)); + let class_idx = self.class_file.get_or_add_class(&desc); + self.emit(Instruction::Multianewarray { + index: class_idx, + dimensions: dimensions.len() as u8, + }); + Ok(()) + } + CExpr::ArrayAccess { array, index } => { +
self.gen_expr(array)?; + self.gen_expr(index)?; + let array_ty = self.infer_expr_type(array); + self.emit_array_load(&array_ty); + Ok(()) + } + CExpr::Cast { ty, operand } => { + let src_ty = self.infer_expr_type(operand); + self.gen_expr(operand)?; + match ty { + TypeName::Primitive(kind) => { + let src_rank = numeric_rank(&src_ty); + let _dst_rank = numeric_rank(ty); + // Same rank or both int-like: may still need narrowing + match (src_rank, kind) { + // Source is int-like + (0, PrimitiveKind::Long) => { + self.emit(Instruction::I2l); + } + (0, PrimitiveKind::Float) => { + self.emit(Instruction::I2f); + } + (0, PrimitiveKind::Double) => { + self.emit(Instruction::I2d); + } + (0, PrimitiveKind::Byte) => { + self.emit(Instruction::I2b); + } + (0, PrimitiveKind::Char) => { + self.emit(Instruction::I2c); + } + (0, PrimitiveKind::Short) => { + self.emit(Instruction::I2s); + } + (0, PrimitiveKind::Int) | (0, PrimitiveKind::Boolean) => {} + // Source is long + (1, PrimitiveKind::Int) => { + self.emit(Instruction::L2i); + } + (1, PrimitiveKind::Float) => { + self.emit(Instruction::L2f); + } + (1, PrimitiveKind::Double) => { + self.emit(Instruction::L2d); + } + (1, PrimitiveKind::Byte) => { + self.emit(Instruction::L2i); + self.emit(Instruction::I2b); + } + (1, PrimitiveKind::Char) => { + self.emit(Instruction::L2i); + self.emit(Instruction::I2c); + } + (1, PrimitiveKind::Short) => { + self.emit(Instruction::L2i); + self.emit(Instruction::I2s); + } + (1, PrimitiveKind::Long) => {} + // Source is float + (2, PrimitiveKind::Int) => { + self.emit(Instruction::F2i); + } + (2, PrimitiveKind::Long) => { + self.emit(Instruction::F2l); + } + (2, PrimitiveKind::Double) => { + self.emit(Instruction::F2d); + } + (2, PrimitiveKind::Byte) => { + self.emit(Instruction::F2i); + self.emit(Instruction::I2b); + } + (2, PrimitiveKind::Char) => { + self.emit(Instruction::F2i); + self.emit(Instruction::I2c); + } + (2, PrimitiveKind::Short) => { + self.emit(Instruction::F2i); + 
self.emit(Instruction::I2s); + } + (2, PrimitiveKind::Float) => {} + // Source is double + (3, PrimitiveKind::Int) => { + self.emit(Instruction::D2i); + } + (3, PrimitiveKind::Long) => { + self.emit(Instruction::D2l); + } + (3, PrimitiveKind::Float) => { + self.emit(Instruction::D2f); + } + (3, PrimitiveKind::Byte) => { + self.emit(Instruction::D2i); + self.emit(Instruction::I2b); + } + (3, PrimitiveKind::Char) => { + self.emit(Instruction::D2i); + self.emit(Instruction::I2c); + } + (3, PrimitiveKind::Short) => { + self.emit(Instruction::D2i); + self.emit(Instruction::I2s); + } + (3, PrimitiveKind::Double) => {} + (_, PrimitiveKind::Void) => { + return Err(CompileError::CodegenError { + message: "cannot cast to void".into(), + }); + } + _ => {} // same type, no-op + } + Ok(()) + } + TypeName::Class(name) => { + let internal = resolve_class_name(name); + let class_idx = self.class_file.get_or_add_class(&internal); + self.emit(Instruction::Checkcast(class_idx)); + Ok(()) + } + TypeName::Array(_) => { + let descriptor = type_name_to_descriptor(ty); + let class_idx = self.class_file.get_or_add_class(&descriptor); + self.emit(Instruction::Checkcast(class_idx)); + Ok(()) + } + } + } + CExpr::Instanceof { operand, ty } => { + self.gen_expr(operand)?; + match ty { + TypeName::Class(name) => { + let internal = resolve_class_name(name); + let class_idx = self.class_file.get_or_add_class(&internal); + self.emit(Instruction::Instanceof(class_idx)); + } + TypeName::Array(_) => { + let descriptor = type_name_to_descriptor(ty); + let class_idx = self.class_file.get_or_add_class(&descriptor); + self.emit(Instruction::Instanceof(class_idx)); + } + _ => { + return Err(CompileError::CodegenError { + message: "instanceof requires class or array type".into(), + }); + } + } + Ok(()) + } + CExpr::Ternary { + condition, + then_expr, + else_expr, + } => { + let false_label = self.new_label(); + let end_label = self.new_label(); + // Both branches push a result value before reaching 
end_label + let result_vtype = + type_name_to_vtype_resolved(&self.infer_expr_type(then_expr), self.class_file); + let result_stack = vec![result_vtype]; + self.label_stack_override.push((false_label, Vec::new())); + self.label_stack_override.push((end_label, result_stack)); + self.gen_condition(condition, false_label, false)?; + self.gen_expr(then_expr)?; + self.emit_goto(end_label); + self.bind_label(false_label); + self.gen_expr(else_expr)?; + self.bind_label(end_label); + Ok(()) + } + CExpr::SwitchExpr { + expr, + cases, + default_expr, + } => self.gen_switch_expr(expr, cases, default_expr), + CExpr::Lambda { params, body } => self.gen_lambda(params, body), + CExpr::MethodRef { + class_name, + method_name, + } => self.gen_method_ref(class_name, method_name), + } + } + + // --- Condition codegen (emit direct branch instructions) --- + + /// Generate condition code. If `jump_on_true`, jumps to `target_label` when condition is true. + /// Otherwise, jumps to `target_label` when condition is false. 
+ fn gen_condition( + &mut self, + expr: &CExpr, + target_label: usize, + jump_on_true: bool, + ) -> Result<(), CompileError> { + match expr { + CExpr::Comparison { op, left, right } => { + // Check for null comparisons + if matches!(right.as_ref(), CExpr::NullLiteral) { + self.gen_expr(left)?; + let branch = if jump_on_true { + match op { + CompareOp::Eq => Instruction::Ifnull as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifnonnull, + _ => { + return Err(CompileError::CodegenError { + message: "cannot compare null with relational operator".into(), + }); + } + } + } else { + match op { + CompareOp::Eq => Instruction::Ifnonnull as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifnull, + _ => { + return Err(CompileError::CodegenError { + message: "cannot compare null with relational operator".into(), + }); + } + } + }; + self.emit_branch(branch, target_label); + return Ok(()); + } + if matches!(left.as_ref(), CExpr::NullLiteral) { + self.gen_expr(right)?; + let branch = if jump_on_true { + match op { + CompareOp::Eq => Instruction::Ifnull as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifnonnull, + _ => { + return Err(CompileError::CodegenError { + message: "cannot compare null with relational operator".into(), + }); + } + } + } else { + match op { + CompareOp::Eq => Instruction::Ifnonnull as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifnull, + _ => { + return Err(CompileError::CodegenError { + message: "cannot compare null with relational operator".into(), + }); + } + } + }; + self.emit_branch(branch, target_label); + return Ok(()); + } + + let left_ty = self.infer_expr_type(left); + let right_ty = self.infer_expr_type(right); + let promoted = promote_numeric_type(&left_ty, &right_ty); + self.gen_expr(left)?; + self.emit_widen_if_needed(&left_ty, &promoted); + self.gen_expr(right)?; + self.emit_widen_if_needed(&right_ty, &promoted); + + if is_reference_type(&promoted) { + // Reference equality: use if_acmpeq / if_acmpne 
+ let branch = match (op, jump_on_true) { + (CompareOp::Eq, true) | (CompareOp::Ne, false) => { + Instruction::IfAcmpeq as fn(i16) -> Instruction + } + (CompareOp::Ne, true) | (CompareOp::Eq, false) => { + Instruction::IfAcmpne as fn(i16) -> Instruction + } + _ => { + return Err(CompileError::CodegenError { + message: "cannot use relational operator on reference types".into(), + }); + } + }; + self.emit_branch(branch, target_label); + } else if is_int_type(&promoted) { + let branch = if jump_on_true { + match op { + CompareOp::Eq => Instruction::IfIcmpeq as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::IfIcmpne, + CompareOp::Lt => Instruction::IfIcmplt, + CompareOp::Le => Instruction::IfIcmple, + CompareOp::Gt => Instruction::IfIcmpgt, + CompareOp::Ge => Instruction::IfIcmpge, + } + } else { + match op { + CompareOp::Eq => Instruction::IfIcmpne as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::IfIcmpeq, + CompareOp::Lt => Instruction::IfIcmpge, + CompareOp::Le => Instruction::IfIcmpgt, + CompareOp::Gt => Instruction::IfIcmple, + CompareOp::Ge => Instruction::IfIcmplt, + } + }; + self.emit_branch(branch, target_label); + } else { + // long/float/double: compare instruction reduces to int, then branch + self.emit_typed_compare(&promoted, op); + let branch = if jump_on_true { + match op { + CompareOp::Eq => Instruction::Ifeq as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifne, + CompareOp::Lt => Instruction::Iflt, + CompareOp::Le => Instruction::Ifle, + CompareOp::Gt => Instruction::Ifgt, + CompareOp::Ge => Instruction::Ifge, + } + } else { + match op { + CompareOp::Eq => Instruction::Ifne as fn(i16) -> Instruction, + CompareOp::Ne => Instruction::Ifeq, + CompareOp::Lt => Instruction::Ifge, + CompareOp::Le => Instruction::Ifgt, + CompareOp::Gt => Instruction::Ifle, + CompareOp::Ge => Instruction::Iflt, + } + }; + self.emit_branch(branch, target_label); + } + Ok(()) + } + CExpr::LogicalAnd(left, right) => { + if jump_on_true { + // a && b 
is true: both must be true + let skip = self.new_label(); + self.gen_condition(left, skip, false)?; + self.gen_condition(right, target_label, true)?; + self.bind_label(skip); + } else { + // a && b is false: either is false + self.gen_condition(left, target_label, false)?; + self.gen_condition(right, target_label, false)?; + } + Ok(()) + } + CExpr::LogicalOr(left, right) => { + if jump_on_true { + // a || b is true: either is true + self.gen_condition(left, target_label, true)?; + self.gen_condition(right, target_label, true)?; + } else { + // a || b is false: both must be false + let skip = self.new_label(); + self.gen_condition(left, skip, true)?; + self.gen_condition(right, target_label, false)?; + self.bind_label(skip); + } + Ok(()) + } + CExpr::LogicalNot(operand) => self.gen_condition(operand, target_label, !jump_on_true), + CExpr::BoolLiteral(true) => { + if jump_on_true { + self.emit_goto(target_label); + } + Ok(()) + } + CExpr::BoolLiteral(false) => { + if !jump_on_true { + self.emit_goto(target_label); + } + Ok(()) + } + _ => { + // Generic: evaluate to int, branch on 0/non-0 + self.gen_expr(expr)?; + let branch = if jump_on_true { + Instruction::Ifne as fn(i16) -> Instruction + } else { + Instruction::Ifeq as fn(i16) -> Instruction + }; + self.emit_branch(branch, target_label); + Ok(()) + } + } + } + + // --- Switch codegen --- + + fn gen_switch( + &mut self, + expr: &CExpr, + cases: &[SwitchCase], + default_body: Option<&[CStmt]>, + ) -> Result<(), CompileError> { + self.gen_expr(expr)?; + + let end_label = self.new_label(); + let default_label = self.new_label(); + + // Collect all (value, case_index) pairs + let mut value_to_case: Vec<(i32, usize)> = Vec::new(); + for (case_idx, case) in cases.iter().enumerate() { + for &v in &case.values { + value_to_case.push((v as i32, case_idx)); + } + } + value_to_case.sort_by_key(|&(v, _)| v); + + // Create labels for each case body + let case_labels: Vec<usize> = cases.iter().map(|_| self.new_label()).collect(); + 
// Override all branch-target labels with pre-switch locals + let pre_case_locals = self.locals.current_locals_vtypes(); + for &label in &case_labels { + self.label_locals_override + .push((label, pre_case_locals.clone())); + } + self.label_locals_override + .push((default_label, pre_case_locals.clone())); + self.label_locals_override + .push((end_label, pre_case_locals)); + + // Decide tableswitch vs lookupswitch + let use_table = if value_to_case.is_empty() { + false + } else { + let low = value_to_case.first().unwrap().0; + let high = value_to_case.last().unwrap().0; + let range = (high as i64 - low as i64 + 1) as usize; + range <= 2 * value_to_case.len() + }; + + if use_table && !value_to_case.is_empty() { + let low = value_to_case.first().unwrap().0; + let high = value_to_case.last().unwrap().0; + + // Build offset labels array: for each index in [low..=high], map to case label or default + let mut offset_labels: Vec<usize> = Vec::new(); + let mut val_idx = 0; + for v in low..=high { + if val_idx < value_to_case.len() && value_to_case[val_idx].0 == v { + offset_labels.push(case_labels[value_to_case[val_idx].1]); + val_idx += 1; + } else { + offset_labels.push(default_label); + } + } + + // Emit placeholder tableswitch + let placeholder = Instruction::Tableswitch { + default: 0, + low, + high, + offsets: vec![0i32; offset_labels.len()], + }; + let instr_idx = self.emit(placeholder); + self.switch_patches.push(SwitchPatch { + instr_idx, + kind: SwitchPatchKind::Table { + low, + high, + case_labels: offset_labels, + default_label, + }, + }); + } else { + // Lookupswitch + let pair_labels: Vec<(i32, usize)> = value_to_case + .iter() + .map(|&(v, case_idx)| (v, case_labels[case_idx])) + .collect(); + + let placeholder = Instruction::Lookupswitch { + default: 0, + npairs: pair_labels.len() as u32, + pairs: pair_labels.iter().map(|&(v, _)| (v, 0i32)).collect(), + }; + let instr_idx = self.emit(placeholder); + self.switch_patches.push(SwitchPatch { + instr_idx, + kind: 
SwitchPatchKind::Lookup { + pairs: pair_labels, + default_label, + }, + }); + } + + // Push breakable context + self.breakable_stack.push(BreakableContext { + break_label: end_label, + is_loop: false, + continue_label: None, + }); + + // Emit case bodies + for (i, case) in cases.iter().enumerate() { + self.bind_label(case_labels[i]); + for s in &case.body { + self.gen_stmt(s)?; + } + } + + // Emit default body + self.bind_label(default_label); + if let Some(body) = default_body { + for s in body { + self.gen_stmt(s)?; + } + } + + self.bind_label(end_label); + self.breakable_stack.pop(); + Ok(()) + } + + fn gen_switch_expr( + &mut self, + expr: &CExpr, + cases: &[SwitchExprCase], + default_expr: &CExpr, + ) -> Result<(), CompileError> { + self.gen_expr(expr)?; + + let end_label = self.new_label(); + let default_label = self.new_label(); + + // Collect all (value, case_index) pairs + let mut value_to_case: Vec<(i32, usize)> = Vec::new(); + for (case_idx, case) in cases.iter().enumerate() { + for &v in &case.values { + value_to_case.push((v as i32, case_idx)); + } + } + value_to_case.sort_by_key(|&(v, _)| v); + + let case_labels: Vec<usize> = cases.iter().map(|_| self.new_label()).collect(); + + // Override all branch-target labels with pre-switch locals + let pre_case_locals = self.locals.current_locals_vtypes(); + for &label in &case_labels { + self.label_locals_override + .push((label, pre_case_locals.clone())); + } + self.label_locals_override + .push((default_label, pre_case_locals.clone())); + self.label_locals_override + .push((end_label, pre_case_locals)); + + // end_label has the switch result on the stack + let result_vtype = + type_name_to_vtype_resolved(&self.infer_expr_type(&cases[0].expr), self.class_file); + self.label_stack_override + .push((end_label, vec![result_vtype])); + + // Decide tableswitch vs lookupswitch + let use_table = if value_to_case.is_empty() { + false + } else { + let low = value_to_case.first().unwrap().0; + let high = 
value_to_case.last().unwrap().0; + let range = (high as i64 - low as i64 + 1) as usize; + range <= 2 * value_to_case.len() + }; + + if use_table && !value_to_case.is_empty() { + let low = value_to_case.first().unwrap().0; + let high = value_to_case.last().unwrap().0; + + let mut offset_labels: Vec<usize> = Vec::new(); + let mut val_idx = 0; + for v in low..=high { + if val_idx < value_to_case.len() && value_to_case[val_idx].0 == v { + offset_labels.push(case_labels[value_to_case[val_idx].1]); + val_idx += 1; + } else { + offset_labels.push(default_label); + } + } + + let placeholder = Instruction::Tableswitch { + default: 0, + low, + high, + offsets: vec![0i32; offset_labels.len()], + }; + let instr_idx = self.emit(placeholder); + self.switch_patches.push(SwitchPatch { + instr_idx, + kind: SwitchPatchKind::Table { + low, + high, + case_labels: offset_labels, + default_label, + }, + }); + } else { + let pair_labels: Vec<(i32, usize)> = value_to_case + .iter() + .map(|&(v, case_idx)| (v, case_labels[case_idx])) + .collect(); + + let placeholder = Instruction::Lookupswitch { + default: 0, + npairs: pair_labels.len() as u32, + pairs: pair_labels.iter().map(|&(v, _)| (v, 0i32)).collect(), + }; + let instr_idx = self.emit(placeholder); + self.switch_patches.push(SwitchPatch { + instr_idx, + kind: SwitchPatchKind::Lookup { + pairs: pair_labels, + default_label, + }, + }); + } + + // Emit case expression bodies: each pushes a value then jumps to end + for (i, case) in cases.iter().enumerate() { + self.bind_label(case_labels[i]); + self.gen_expr(&case.expr)?; + self.emit_goto(end_label); + } + + // Default + self.bind_label(default_label); + self.gen_expr(default_expr)?; + + self.bind_label(end_label); + Ok(()) + } + + // --- Lambda codegen --- + + fn gen_lambda( + &mut self, + params: &[LambdaParam], + body: &LambdaBody, + ) -> Result<(), CompileError> { + // Generate a synthetic static method name + let lambda_idx = self.class_file.methods.len(); + let lambda_name = 
format!("lambda${}", lambda_idx); + + // Build lambda method descriptor from params + // All typed params contribute to descriptor; untyped params default to Object + let mut param_descs = Vec::new(); + for p in params { + let desc = match &p.ty { + Some(ty) => type_name_to_descriptor(ty), + None => "Ljava/lang/Object;".to_string(), + }; + param_descs.push(desc); + } + + // Infer return type from body + let return_desc = match body { + LambdaBody::Expr(expr) => { + let ty = self.infer_expr_type(expr); + type_name_to_descriptor(&ty) + } + LambdaBody::Block(_) => "V".to_string(), // Block lambdas default to void + }; + + let lambda_descriptor = format!("({}){}", param_descs.join(""), return_desc); + + // Build the SAM method descriptor (functional interface method type) + // This is the same as the lambda descriptor for simple cases + let sam_descriptor = lambda_descriptor.clone(); + + // Create the synthetic method body + let stmts: Vec<CStmt> = match body { + LambdaBody::Expr(expr) => { + if return_desc == "V" { + vec![CStmt::ExprStmt(expr.as_ref().clone())] + } else { + vec![CStmt::Return(Some(expr.as_ref().clone()))] + } + } + LambdaBody::Block(stmts) => stmts.clone(), + }; + + // Generate bytecode for the synthetic method + let mut lambda_codegen = CodeGenerator::new( + self.class_file, + true, // lambda methods are static + &lambda_descriptor, + &[], // lambda params don't have debug names + )?; + lambda_codegen.generate_body(&stmts)?; + let generated = lambda_codegen.finish()?; + + // Create Code attribute + let code_name_idx = self.class_file.get_or_add_utf8("Code"); + let exception_table_length = generated.exception_table.len() as u16; + let code_attr = CodeAttribute { + max_stack: generated.max_stack, + max_locals: generated.max_locals, + code_length: 0, + code: generated.instructions, + exception_table_length, + exception_table: generated.exception_table, + attributes_count: 0, + attributes: Vec::new(), + }; + + let mut attr_info = AttributeInfo { + 
attribute_name_index: code_name_idx, + attribute_length: 0, + info: vec![], + info_parsed: Some(AttributeInfoVariant::Code(code_attr)), + }; + attr_info + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed for lambda Code failed: {}", e), + })?; + + // Create the MethodInfo for the synthetic method + let name_idx = self.class_file.get_or_add_utf8(&lambda_name); + let desc_idx = self.class_file.get_or_add_utf8(&lambda_descriptor); + + let method_info = MethodInfo { + access_flags: MethodAccessFlags::PRIVATE + | MethodAccessFlags::STATIC + | MethodAccessFlags::SYNTHETIC, + name_index: name_idx, + descriptor_index: desc_idx, + attributes_count: 1, + attributes: vec![attr_info], + }; + self.class_file.methods.push(method_info); + self.class_file.sync_counts(); + + // Set up bootstrap method for LambdaMetafactory.metafactory + let bootstrap_idx = self.get_or_add_lambda_bootstrap()?; + + // Build the invokedynamic constant + // The invokedynamic call site produces the functional interface instance. + // Use the correct SAM name for the guessed functional interface. + let fi_class = self.guess_functional_interface(&param_descs, &return_desc); + let invoke_name = match fi_class.as_str() { + "java/lang/Runnable" => "run", + "java/util/function/Supplier" => "get", + "java/util/function/Consumer" => "accept", + "java/util/function/Predicate" => "test", + _ => "apply", + }; + let invoke_desc = format!("()L{};", fi_class); + let indy_idx = + self.class_file + .get_or_add_invoke_dynamic(bootstrap_idx, invoke_name, &invoke_desc); + + // Add bootstrap method arguments: + // 1. MethodType: SAM method type (e.g., "()V" for Runnable.run) + // 2. MethodHandle: impl method (our synthetic lambda method) + // 3. 
MethodType: instantiated method type (same as SAM type for non-generic) + let this_class_name = self.get_this_class_name()?; + let impl_method_ref = self.class_file.get_or_add_method_ref( + &this_class_name, + &lambda_name, + &lambda_descriptor, + ); + let impl_handle = self.class_file.get_or_add_method_handle( + 6, // REF_invokeStatic + impl_method_ref, + ); + let sam_method_type = self.class_file.get_or_add_method_type(&sam_descriptor); + let instantiated_method_type = self.class_file.get_or_add_method_type(&sam_descriptor); + + // Update the bootstrap method entry with the arguments + self.update_bootstrap_args( + bootstrap_idx, + &[sam_method_type, impl_handle, instantiated_method_type], + )?; + + // Emit the invokedynamic instruction + self.emit(Instruction::Invokedynamic { + index: indy_idx, + filler: 0, + }); + + Ok(()) + } + + fn gen_method_ref(&mut self, class_name: &str, method_name: &str) -> Result<(), CompileError> { + let internal_class = resolve_class_name(class_name); + + // Resolve the method descriptor from the constant pool or well-known methods + let impl_descriptor = self.find_method_ref_descriptor(&internal_class, method_name)?; + let (params, ret) = parse_method_descriptor(&impl_descriptor).ok_or_else(|| { + CompileError::CodegenError { + message: format!( + "invalid method descriptor for {}::{}: {}", + class_name, method_name, impl_descriptor + ), + } + })?; + + // Build erased SAM descriptor: all reference params → Object, primitives stay + let erased_params: Vec<String> = params.iter().map(erase_to_object).collect(); + let erased_ret = erase_to_object(&ret); + let sam_desc = format!("({}){}", erased_params.join(""), erased_ret); + + // Determine functional interface from param/return types + let param_descs: Vec<String> = params.iter().map(|p| p.to_descriptor()).collect(); + let ret_desc = ret.to_descriptor(); + let fi_class = self.guess_functional_interface(&param_descs, &ret_desc); + let sam_name = match fi_class.as_str() { + "java/lang/Runnable" => "run", + 
"java/util/function/Consumer" => "accept", + "java/util/function/Predicate" => "test", + "java/util/function/Supplier" => "get", + _ => "apply", + }; + + // invokedynamic descriptor: () -> FunctionalInterface + let indy_desc = format!("()L{};", fi_class); + + let bootstrap_idx = self.get_or_add_lambda_bootstrap()?; + + let indy_idx = + self.class_file + .get_or_add_invoke_dynamic(bootstrap_idx, sam_name, &indy_desc); + + let method_ref = + self.class_file + .get_or_add_method_ref(&internal_class, method_name, &impl_descriptor); + let impl_handle = self.class_file.get_or_add_method_handle(6, method_ref); + let sam_type = self.class_file.get_or_add_method_type(&sam_desc); + let inst_type = self.class_file.get_or_add_method_type(&impl_descriptor); + + self.update_bootstrap_args(bootstrap_idx, &[sam_type, impl_handle, inst_type])?; + + self.emit(Instruction::Invokedynamic { + index: indy_idx, + filler: 0, + }); + + Ok(()) + } + + /// Resolve a method descriptor for a method reference (Class::method). + /// Searches the constant pool first, then falls back to well-known methods. 
+ fn find_method_ref_descriptor( + &self, + internal_class: &str, + method_name: &str, + ) -> Result<String, CompileError> { + use crate::constant_info::ConstantInfo; + let pool = &self.class_file.const_pool; + + // Search constant pool for MethodRef or InterfaceMethodRef matching class + name + for entry in pool.iter() { + let (class_idx, nat_idx) = match entry { + ConstantInfo::MethodRef(r) => (r.class_index, r.name_and_type_index), + ConstantInfo::InterfaceMethodRef(r) => (r.class_index, r.name_and_type_index), + _ => continue, + }; + // Check class name matches + if let Some(ConstantInfo::Class(cls)) = pool.get((class_idx - 1) as usize) + && let Some(cls_name) = self.class_file.get_utf8(cls.name_index) + && cls_name != internal_class + { + continue; + } + // Check method name and get descriptor + if let Some(ConstantInfo::NameAndType(nat)) = pool.get((nat_idx - 1) as usize) + && let Some(name) = self.class_file.get_utf8(nat.name_index) + && name == method_name + && let Some(desc) = self.class_file.get_utf8(nat.descriptor_index) + { + return Ok(desc.to_string()); + } + } + + // Well-known method descriptors + match (internal_class, method_name) { + ("java/lang/String", "valueOf") => Ok("(Ljava/lang/Object;)Ljava/lang/String;".into()), + ("java/lang/Integer", "parseInt") => Ok("(Ljava/lang/String;)I".into()), + ("java/lang/Integer", "valueOf") => Ok("(I)Ljava/lang/Integer;".into()), + ("java/lang/Long", "parseLong") => Ok("(Ljava/lang/String;)J".into()), + ("java/lang/Long", "valueOf") => Ok("(J)Ljava/lang/Long;".into()), + ("java/lang/Double", "parseDouble") => Ok("(Ljava/lang/String;)D".into()), + ("java/lang/Double", "valueOf") => Ok("(D)Ljava/lang/Double;".into()), + ("java/lang/Boolean", "parseBoolean") => Ok("(Ljava/lang/String;)Z".into()), + ("java/lang/System", "exit") => Ok("(I)V".into()), + _ => Err(CompileError::CodegenError { + message: format!( + "cannot resolve method descriptor for {}::{}; \ + ensure the method is called elsewhere so its descriptor is in the constant 
pool", + internal_class, method_name + ), + }), + } + } + + /// Guess the functional interface class based on parameter/return types. + fn guess_functional_interface(&self, param_descs: &[String], return_desc: &str) -> String { + match (param_descs.len(), return_desc) { + (0, "V") => "java/lang/Runnable".into(), + (0, _) => "java/util/function/Supplier".into(), + (1, "V") => "java/util/function/Consumer".into(), + (1, "Z") => "java/util/function/Predicate".into(), + (1, _) => "java/util/function/Function".into(), + (2, _) => "java/util/function/BiFunction".into(), + _ => "java/lang/Runnable".into(), + } + } + + /// Find or create the bootstrap method entry for LambdaMetafactory.metafactory. + fn get_or_add_lambda_bootstrap(&mut self) -> Result<u16, CompileError> { + // Build the method handle for LambdaMetafactory.metafactory + let metafactory_ref = self.class_file.get_or_add_method_ref( + "java/lang/invoke/LambdaMetafactory", + "metafactory", + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", + ); + let metafactory_handle = self.class_file.get_or_add_method_handle( + 6, // REF_invokeStatic + metafactory_ref, + ); + + // Find existing BootstrapMethods attribute on the class, or create one + let bsm_attr_idx = self.class_file.attributes.iter().position(|a| { + matches!( + &a.info_parsed, + Some(AttributeInfoVariant::BootstrapMethods(_)) + ) + }); + + if let Some(idx) = bsm_attr_idx { + // Each lambda/method-reference needs its own bootstrap method entry. + // Do NOT reuse entries; different lambdas use different impl method handles. 
+ + // Add new bootstrap method entry + let new_bm = BootstrapMethod { + bootstrap_method_ref: metafactory_handle, + num_bootstrap_arguments: 0, + bootstrap_arguments: Vec::new(), + }; + + if let Some(AttributeInfoVariant::BootstrapMethods(bsm)) = + &mut self.class_file.attributes[idx].info_parsed + { + let new_idx = bsm.bootstrap_methods.len() as u16; + bsm.bootstrap_methods.push(new_bm); + bsm.num_bootstrap_methods = bsm.bootstrap_methods.len() as u16; + return Ok(new_idx); + } + unreachable!() + } else { + // Create new BootstrapMethods attribute + let name_idx = self.class_file.get_or_add_utf8("BootstrapMethods"); + let bsm_attr = BootstrapMethodsAttribute { + num_bootstrap_methods: 1, + bootstrap_methods: vec![BootstrapMethod { + bootstrap_method_ref: metafactory_handle, + num_bootstrap_arguments: 0, + bootstrap_arguments: Vec::new(), + }], + }; + let mut attr_info = AttributeInfo { + attribute_name_index: name_idx, + attribute_length: 0, + info: vec![], + info_parsed: Some(AttributeInfoVariant::BootstrapMethods(bsm_attr)), + }; + attr_info + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed for BootstrapMethods failed: {}", e), + })?; + self.class_file.attributes.push(attr_info); + self.class_file.sync_counts(); + Ok(0) + } + } + + /// Update bootstrap method arguments for a specific bootstrap method index. 
+ fn update_bootstrap_args( + &mut self, + bootstrap_idx: u16, + args: &[u16], + ) -> Result<(), CompileError> { + let bsm_attr_idx = self + .class_file + .attributes + .iter() + .position(|a| { + matches!( + &a.info_parsed, + Some(AttributeInfoVariant::BootstrapMethods(_)) + ) + }) + .ok_or_else(|| CompileError::CodegenError { + message: "BootstrapMethods attribute not found".into(), + })?; + + if let Some(AttributeInfoVariant::BootstrapMethods(bsm)) = + &mut self.class_file.attributes[bsm_attr_idx].info_parsed + { + let bm = &mut bsm.bootstrap_methods[bootstrap_idx as usize]; + bm.bootstrap_arguments = args.to_vec(); + bm.num_bootstrap_arguments = args.len() as u16; + } + + // Re-sync the attribute + self.class_file.attributes[bsm_attr_idx] + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed for BootstrapMethods failed: {}", e), + })?; + self.class_file.sync_counts(); + + Ok(()) + } + + // --- For-each codegen --- + + fn gen_foreach( + &mut self, + element_type: &TypeName, + var_name: &str, + iterable: &CExpr, + body: &[CStmt], + ) -> Result<(), CompileError> { + let iterable_ty = self.infer_expr_type(iterable); + + if matches!(iterable_ty, TypeName::Array(_)) { + self.gen_foreach_array(element_type, var_name, iterable, body) + } else { + self.gen_foreach_iterable(element_type, var_name, iterable, body) + } + } + + fn gen_foreach_array( + &mut self, + element_type: &TypeName, + var_name: &str, + iterable: &CExpr, + body: &[CStmt], + ) -> Result<(), CompileError> { + let array_ty = self.infer_expr_type(iterable); + + // Save allocator state before for-each internal locals so they don't leak + let saved_locals = self.locals.save(); + + // Allocate temp slots (but NOT the loop variable yet — it must come after + // the loop-top label so the stack map frame doesn't include it on the first entry) + let arr_vtype = type_name_to_vtype_resolved(&array_ty, self.class_file); + let arr_slot = self + .locals + 
.allocate_with_vtype("__foreach_arr", &array_ty, arr_vtype); + let len_slot = self + .locals + .allocate("__foreach_len", &TypeName::Primitive(PrimitiveKind::Int)); + let idx_slot = self + .locals + .allocate("__foreach_idx", &TypeName::Primitive(PrimitiveKind::Int)); + + // Evaluate iterable, store array ref + self.gen_expr(iterable)?; + self.emit_store(&array_ty, arr_slot); + + // arraylength → len + self.emit_load(&array_ty, arr_slot); + self.emit(Instruction::Arraylength); + self.emit_store(&TypeName::Primitive(PrimitiveKind::Int), len_slot); + + // i = 0 + self.emit(Instruction::Iconst0); + self.emit_store(&TypeName::Primitive(PrimitiveKind::Int), idx_slot); + + // Loop labels + let top_label = self.new_label(); + let update_label = self.new_label(); + let end_label = self.new_label(); + + // Override end_label frame to use pre-loop-variable locals + let pre_loop_locals = self.locals.current_locals_vtypes(); + self.label_locals_override + .push((end_label, pre_loop_locals)); + + self.breakable_stack.push(BreakableContext { + break_label: end_label, + is_loop: true, + continue_label: Some(update_label), + }); + + // top: if (i >= len) goto end + self.bind_label(top_label); + self.emit_load(&TypeName::Primitive(PrimitiveKind::Int), idx_slot); + self.emit_load(&TypeName::Primitive(PrimitiveKind::Int), len_slot); + self.emit_branch(Instruction::IfIcmpge, end_label); + + // Allocate loop variable AFTER the loop-top frame is recorded + let elem_vtype = type_name_to_vtype_resolved(element_type, self.class_file); + let var_slot = self + .locals + .allocate_with_vtype(var_name, element_type, elem_vtype); + + // var = arr[i] + self.emit_load(&array_ty, arr_slot); + self.emit_load(&TypeName::Primitive(PrimitiveKind::Int), idx_slot); + self.emit_array_load(&array_ty); + self.emit_store(element_type, var_slot); + + // body + for stmt in body { + self.gen_stmt(stmt)?; + } + + // update: i++ + self.bind_label(update_label); + if idx_slot <= 255 { + 
self.emit(Instruction::Iinc { + index: idx_slot as u8, + value: 1, + }); + } else { + self.emit_load(&TypeName::Primitive(PrimitiveKind::Int), idx_slot); + self.emit(Instruction::Iconst1); + self.emit(Instruction::Iadd); + self.emit_store(&TypeName::Primitive(PrimitiveKind::Int), idx_slot); + } + self.emit_goto(top_label); + + // Restore allocator so for-each internal locals don't leak + self.locals.restore(saved_locals); + self.bind_label(end_label); + self.breakable_stack.pop(); + Ok(()) + } + + fn gen_foreach_iterable( + &mut self, + element_type: &TypeName, + var_name: &str, + iterable: &CExpr, + body: &[CStmt], + ) -> Result<(), CompileError> { + // Save allocator state before for-each internal locals so they don't leak + let saved_locals = self.locals.save(); + + let iter_ty = TypeName::Class("java/util/Iterator".into()); + let iter_vtype = type_name_to_vtype_resolved(&iter_ty, self.class_file); + let iter_slot = self + .locals + .allocate_with_vtype("__foreach_iter", &iter_ty, iter_vtype); + + // iterable.iterator() → iter_slot + self.gen_expr(iterable)?; + let iterator_idx = self.class_file.get_or_add_interface_method_ref( + "java/lang/Iterable", + "iterator", + "()Ljava/util/Iterator;", + ); + self.emit(Instruction::Invokeinterface { + index: iterator_idx, + count: 1, + filler: 0, + }); + self.emit_store(&iter_ty, iter_slot); + + let top_label = self.new_label(); + let end_label = self.new_label(); + + // Override end_label frame to use pre-loop-variable locals + let pre_loop_locals = self.locals.current_locals_vtypes(); + self.label_locals_override + .push((end_label, pre_loop_locals)); + + self.breakable_stack.push(BreakableContext { + break_label: end_label, + is_loop: true, + continue_label: Some(top_label), + }); + + // top: if (!iter.hasNext()) goto end + self.bind_label(top_label); + self.emit_load(&iter_ty, iter_slot); + let has_next_idx = + self.class_file + .get_or_add_interface_method_ref("java/util/Iterator", "hasNext", "()Z"); + 
self.emit(Instruction::Invokeinterface { + index: has_next_idx, + count: 1, + filler: 0, + }); + self.emit_branch(Instruction::Ifeq, end_label); + + // Allocate loop variable AFTER the loop-top frame is recorded + let elem_vtype = type_name_to_vtype_resolved(element_type, self.class_file); + let var_slot = self + .locals + .allocate_with_vtype(var_name, element_type, elem_vtype); + + // var = (ElementType) iter.next() + self.emit_load(&iter_ty, iter_slot); + let next_idx = self.class_file.get_or_add_interface_method_ref( + "java/util/Iterator", + "next", + "()Ljava/lang/Object;", + ); + self.emit(Instruction::Invokeinterface { + index: next_idx, + count: 1, + filler: 0, + }); + // Checkcast if element type is not Object + if let TypeName::Class(name) = element_type + && name != "Object" + && name != "java/lang/Object" + && name != "java.lang.Object" + { + let internal = resolve_class_name(name); + let class_idx = self.class_file.get_or_add_class(&internal); + self.emit(Instruction::Checkcast(class_idx)); + } + self.emit_store(element_type, var_slot); + + // body + for stmt in body { + self.gen_stmt(stmt)?; + } + self.emit_goto(top_label); + + // Restore allocator so for-each internal locals don't leak + self.locals.restore(saved_locals); + self.bind_label(end_label); + self.breakable_stack.pop(); + Ok(()) + } + + // --- Try-catch codegen --- + + fn gen_try_catch( + &mut self, + try_body: &[CStmt], + catches: &[CatchClause], + finally_body: Option<&[CStmt]>, + ) -> Result<(), CompileError> { + let try_start = self.new_label(); + let try_end = self.new_label(); + let after_all = self.new_label(); + + // Capture locals at try-start for exception handler frames and merge point + let locals_at_try_start = self.locals.current_locals_vtypes(); + // The after_all merge point must use try-start locals, since the try-body exit path + // doesn't have catch-allocated locals. 
+ self.label_locals_override + .push((after_all, locals_at_try_start.clone())); + + // Emit try body + self.bind_label(try_start); + for s in try_body { + self.gen_stmt(s)?; + } + self.bind_label(try_end); + + // Inline finally at end of try (if present), then goto after_all + if let Some(fin_body) = finally_body { + for s in fin_body { + self.gen_stmt(s)?; + } + } + self.emit_goto(after_all); + + // Save allocator state before catch/finally handlers so their locals don't leak + let saved_locals = self.locals.save(); + + // Emit each catch handler + let mut catch_handler_labels = Vec::new(); + for catch in catches { + // Restore allocator for each catch handler so previous catch locals don't leak + self.locals.restore(saved_locals.clone()); + let handler_label = self.new_label(); + catch_handler_labels.push(handler_label); + // Register as exception handler for frame tracking. + // For multi-catch, use java/lang/Throwable as the stack map type since + // we cannot compute the least common ancestor without class hierarchy + // knowledge. For single-catch, use the exact exception type. + let default_type = TypeName::Class("java/lang/Exception".into()); + let first_type = catch.exception_types.first().unwrap_or(&default_type); + let frame_class_name = if catch.exception_types.len() > 1 { + "java/lang/Throwable" + } else { + match first_type { + TypeName::Class(name) => name.as_str(), + _ => "java/lang/Exception", + } + }; + let internal = resolve_class_name(frame_class_name); + let ex_class_idx = self.class_file.get_or_add_class(&internal); + self.exception_handler_labels.push(( + handler_label, + VType::Object(ex_class_idx), + locals_at_try_start.clone(), + )); + self.bind_label(handler_label); + + // astore exception to a local. + // For multi-catch, the JVM verifier treats the exception variable as the + // LCA type (java/lang/Throwable), matching the stack-map frame type. 
+ // Using a more-specific type like the first declared type would produce + // a verifier error because the StackMapTable frame recorded Throwable on + // the stack, so after astore the local is Throwable — not IllegalArgException. + let (local_type, ex_vtype) = if catch.exception_types.len() > 1 { + let throwable_ty = TypeName::Class("java/lang/Throwable".into()); + let vtype = type_name_to_vtype_resolved(&throwable_ty, self.class_file); + (throwable_ty, vtype) + } else { + let vtype = type_name_to_vtype_resolved(first_type, self.class_file); + (first_type.clone(), vtype) + }; + let ex_slot = self + .locals + .allocate_with_vtype(&catch.var_name, &local_type, ex_vtype); + self.emit_store(&local_type, ex_slot); + + // Emit catch body, tracking whether it ends with an unconditional transfer. + let before_body = self.instructions.len(); + for s in &catch.body { + self.gen_stmt(s)?; + } + let body_ends_with_transfer = + self.instructions.len() > before_body && self.last_is_unconditional_transfer(); + + // Inline finally (if present) + if let Some(fin_body) = finally_body { + for s in fin_body { + self.gen_stmt(s)?; + } + } + // Skip goto if the catch body already produced an unconditional transfer + // (e.g., `continue`, `break`, `return`, `throw`). Emitting a dead goto + // after an unconditional branch requires a StackMapTable frame that the + // JVM verifier would complain about. 
+ if !body_ends_with_transfer { + self.emit_goto(after_all); + } + } + + // If finally present: emit catch-all handler that stores exception, runs finally, rethrows + let catch_all_label = if finally_body.is_some() { + self.locals.restore(saved_locals.clone()); + let label = self.new_label(); + // Register as exception handler for frame tracking + let throwable_idx = self.class_file.get_or_add_class("java/lang/Throwable"); + self.exception_handler_labels.push(( + label, + VType::Object(throwable_idx), + locals_at_try_start.clone(), + )); + self.bind_label(label); + let ex_ty = TypeName::Class("java/lang/Throwable".into()); + let ex_vtype = type_name_to_vtype_resolved(&ex_ty, self.class_file); + let ex_slot = self + .locals + .allocate_with_vtype("__finally_ex", &ex_ty, ex_vtype); + self.emit_store(&ex_ty, ex_slot); + if let Some(fin_body) = finally_body { + for s in fin_body { + self.gen_stmt(s)?; + } + } + self.emit_load(&ex_ty, ex_slot); + self.emit(Instruction::Athrow); + Some(label) + } else { + None + }; + + // Restore allocator so catch/finally locals don't leak into subsequent code + self.locals.restore(saved_locals); + + self.bind_label(after_all); + + // Register exception table entries (one per exception type per catch) + for (i, catch) in catches.iter().enumerate() { + for exc_type in &catch.exception_types { + let internal = resolve_class_name(match exc_type { + TypeName::Class(name) => name.as_str(), + _ => "java/lang/Exception", + }); + let catch_type = self.class_file.get_or_add_class(&internal); + self.pending_exceptions.push(PendingExceptionEntry { + start_label: try_start, + end_label: try_end, + handler_label: catch_handler_labels[i], + catch_type, + }); + } + } + + // Catch-all for finally + if let Some(catch_all) = catch_all_label { + self.pending_exceptions.push(PendingExceptionEntry { + start_label: try_start, + end_label: try_end, + handler_label: catch_all, + catch_type: 0, // 0 = catch all + }); + } + + Ok(()) + } + + // --- Synchronized 
codegen --- + + fn gen_synchronized(&mut self, lock_expr: &CExpr, body: &[CStmt]) -> Result<(), CompileError> { + let lock_ty = TypeName::Class("java/lang/Object".into()); + let lock_vtype = type_name_to_vtype_resolved(&lock_ty, self.class_file); + let lock_slot = self + .locals + .allocate_with_vtype("__sync_lock", &lock_ty, lock_vtype); + + // Evaluate lock expression, dup, store to temp, monitorenter + self.gen_expr(lock_expr)?; + self.emit(Instruction::Dup); + self.emit_store(&lock_ty, lock_slot); + self.emit(Instruction::Monitorenter); + + let try_start = self.new_label(); + let try_end = self.new_label(); + let catch_handler = self.new_label(); + let after_all = self.new_label(); + + // Capture locals at try-start for exception handler frames + let locals_at_try_start = self.locals.current_locals_vtypes(); + self.label_locals_override + .push((after_all, locals_at_try_start.clone())); + + // Try body + self.bind_label(try_start); + for s in body { + self.gen_stmt(s)?; + } + self.bind_label(try_end); + + // Normal exit: monitorexit + goto after_all + self.emit_load(&lock_ty, lock_slot); + self.emit(Instruction::Monitorexit); + self.emit_goto(after_all); + + // Save allocator state before catch-all handler so its locals don't leak + let saved_locals = self.locals.save(); + + // Catch-all handler: store exception, monitorexit, rethrow + let throwable_idx = self.class_file.get_or_add_class("java/lang/Throwable"); + self.exception_handler_labels.push(( + catch_handler, + VType::Object(throwable_idx), + locals_at_try_start, + )); + self.bind_label(catch_handler); + + let ex_ty = TypeName::Class("java/lang/Throwable".into()); + let sync_ex_vtype = type_name_to_vtype_resolved(&ex_ty, self.class_file); + let ex_slot = self + .locals + .allocate_with_vtype("__sync_ex", &ex_ty, sync_ex_vtype); + self.emit_store(&ex_ty, ex_slot); + self.emit_load(&lock_ty, lock_slot); + self.emit(Instruction::Monitorexit); + self.emit_load(&ex_ty, ex_slot); + 
self.emit(Instruction::Athrow); + + // Restore allocator so catch-all locals don't leak + self.locals.restore(saved_locals); + + self.bind_label(after_all); + + // Exception table: [try_start, try_end) → catch_handler, catch_type=0 (catch all) + self.pending_exceptions.push(PendingExceptionEntry { + start_label: try_start, + end_label: try_end, + handler_label: catch_handler, + catch_type: 0, + }); + + Ok(()) + } + + // --- Method/field resolution --- + + fn gen_method_call( + &mut self, + object: Option<&CExpr>, + name: &str, + args: &[CExpr], + ) -> Result<(), CompileError> { + match object { + Some(obj) => { + // Check for string concatenation: "something" + ... results in StringBuilder pattern + // But first, check if this is a chain like System.out.println + // Try to resolve as chain: detect Class.field.method pattern + if let Some((class_name, field_chain, is_static_root)) = self.resolve_dot_chain(obj) + && is_static_root + { + if field_chain.is_empty() { + // Direct static method call: ClassName.method(args) + return self.gen_static_method_call(&class_name, name, args); + } + // Static field access chain, e.g. 
System.out.println + self.gen_static_chain_method_call(&class_name, &field_chain, name, args)?; + return Ok(()); + } + + // Regular instance method call + self.gen_expr(obj)?; + for arg in args { + self.gen_expr(arg)?; + } + // Try to find the method ref in the constant pool + let descriptor = self.find_method_descriptor_in_pool(name, args)?; + let class_name = self.infer_receiver_class(obj)?; + if self.is_interface_method(&class_name, name) { + let method_idx = self.class_file.get_or_add_interface_method_ref( + &class_name, + name, + &descriptor, + ); + let count = compute_invokeinterface_count(&descriptor); + self.emit(Instruction::Invokeinterface { + index: method_idx, + count, + filler: 0, + }); + } else { + let method_idx = + self.class_file + .get_or_add_method_ref(&class_name, name, &descriptor); + self.emit(Instruction::Invokevirtual(method_idx)); + } + Ok(()) + } + None => { + // Unqualified method call - call on `this` + if !self.is_static { + self.emit(Instruction::Aload0); // this + } + for arg in args { + self.gen_expr(arg)?; + } + let descriptor = self.find_method_descriptor_in_pool(name, args)?; + let this_class = self.get_this_class_name()?; + let method_idx = + self.class_file + .get_or_add_method_ref(&this_class, name, &descriptor); + if self.is_static { + self.emit(Instruction::Invokestatic(method_idx)); + } else { + self.emit(Instruction::Invokevirtual(method_idx)); + } + Ok(()) + } + } + } + + fn gen_static_method_call( + &mut self, + class_name: &str, + name: &str, + args: &[CExpr], + ) -> Result<(), CompileError> { + for arg in args { + self.gen_expr(arg)?; + } + let internal = resolve_class_name(class_name); + let descriptor = self.find_method_descriptor_in_pool(name, args)?; + let method_idx = self + .class_file + .get_or_add_method_ref(&internal, name, &descriptor); + self.emit(Instruction::Invokestatic(method_idx)); + Ok(()) + } + + fn gen_field_access(&mut self, object: &CExpr, name: &str) -> Result<(), CompileError> { + // Check if this 
is a static field access (e.g., System.out) + if let CExpr::Ident(ident_name) = object { + // Check constant pool for a FieldRef with this class name + let resolved = resolve_class_name(ident_name); + if self.has_field_ref_for_class(&resolved, name) + || ident_name + .chars() + .next() + .is_some_and(|c| c.is_ascii_uppercase()) + { + return self.gen_static_field_access(ident_name, name); + } + } + // Handle array.length → arraylength instruction + if name == "length" { + let ty = self.infer_expr_type(object); + if matches!(ty, TypeName::Array(_)) { + self.gen_expr(object)?; + self.emit(Instruction::Arraylength); + return Ok(()); + } + } + self.gen_expr(object)?; + let class_name = self.infer_receiver_class(object)?; + let descriptor = self.find_field_descriptor_in_pool(&class_name, name)?; + let field_idx = self + .class_file + .get_or_add_field_ref(&class_name, name, &descriptor); + self.emit(Instruction::Getfield(field_idx)); + Ok(()) + } + + fn gen_static_field_access( + &mut self, + class_name: &str, + name: &str, + ) -> Result<(), CompileError> { + let internal = resolve_class_name(class_name); + let descriptor = self.find_field_descriptor_in_pool(&internal, name)?; + let field_idx = self + .class_file + .get_or_add_field_ref(&internal, name, &descriptor); + self.emit(Instruction::Getstatic(field_idx)); + Ok(()) + } + + /// Handle chains like System.out.println(x): + /// System is the class, out is the static field, println is the method + fn gen_static_chain_method_call( + &mut self, + class_name: &str, + field_chain: &[String], + method_name: &str, + args: &[CExpr], + ) -> Result<(), CompileError> { + let internal = resolve_class_name(class_name); + + // Start with the static field access + let first_field = &field_chain[0]; + let field_desc = self.find_field_descriptor_in_pool(&internal, first_field)?; + let field_idx = self + .class_file + .get_or_add_field_ref(&internal, first_field, &field_desc); + self.emit(Instruction::Getstatic(field_idx)); + + // For 
subsequent fields in the chain, use getfield + let mut current_type_desc = field_desc; + for field_name in &field_chain[1..] { + let field_class = descriptor_to_internal(&current_type_desc)?; + let fd = self.find_field_descriptor_in_pool(&field_class, field_name)?; + let fi = self + .class_file + .get_or_add_field_ref(&field_class, field_name, &fd); + self.emit(Instruction::Getfield(fi)); + current_type_desc = fd; + } + + // Now emit args and invoke the method on the field's type + for arg in args { + self.gen_expr(arg)?; + } + + let receiver_class = descriptor_to_internal(&current_type_desc)?; + let method_desc = self.find_method_descriptor_in_pool(method_name, args)?; + let method_idx = + self.class_file + .get_or_add_method_ref(&receiver_class, method_name, &method_desc); + self.emit(Instruction::Invokevirtual(method_idx)); + Ok(()) + } + + /// Try to resolve a dot-chain to (root_class, [field_names], is_static). + fn resolve_dot_chain(&self, expr: &CExpr) -> Option<(String, Vec<String>, bool)> { + let mut fields = Vec::new(); + let mut current = expr; + + loop { + match current { + CExpr::FieldAccess { object, name } => { + fields.push(name.clone()); + current = object; + } + CExpr::Ident(name) => { + // Check if this is a class name (starts with uppercase) + if name.chars().next().is_some_and(|c| c.is_ascii_uppercase()) { + fields.reverse(); + return Some((name.clone(), fields, true)); + } + return None; + } + _ => return None, + } + } + } + + // --- Store target --- + + fn gen_store_target(&mut self, target: &CExpr) -> Result<(), CompileError> { + match target { + CExpr::Ident(name) => { + let (slot, ty) = + self.locals + .find(name) + .ok_or_else(|| CompileError::CodegenError { + message: format!("undefined variable: {}", name), + })?; + let ty = ty.clone(); + self.emit_store(&ty, slot); + Ok(()) + } + CExpr::FieldAccess { object, name } => { + // value is on stack, we need [objectref, value] for putfield + let class_name = self.infer_receiver_class(object)?; + let descriptor =
self.find_field_descriptor_in_pool(&class_name, name)?; + let field_idx = + self.class_file + .get_or_add_field_ref(&class_name, name, &descriptor); + + let value_ty = descriptor_to_type(&descriptor); + if type_slot_width(&value_ty) == 2 { + // Category-2 value (long/double): Swap won't work. + // Store value to temp, push objectref, reload value. + let tmp_vtype = type_name_to_vtype_resolved(&value_ty, self.class_file); + let temp = self + .locals + .allocate_with_vtype("__field_tmp", &value_ty, tmp_vtype); + self.emit_store(&value_ty, temp); + self.gen_expr(object)?; + self.emit_load(&value_ty, temp); + } else { + // Category-1: Swap works fine + self.gen_expr(object)?; + self.emit(Instruction::Swap); + } + self.emit(Instruction::Putfield(field_idx)); + Ok(()) + } + CExpr::ArrayAccess { array, index } => { + // Stack has: [..., value]. We need: [..., arrayref, index, value] + // Strategy: store value to temp, push arrayref+index, reload value, store + let array_ty = self.infer_expr_type(array); + let elem_ty = match &array_ty { + TypeName::Array(inner) => inner.as_ref().clone(), + _ => TypeName::Primitive(PrimitiveKind::Int), + }; + let arr_tmp_vtype = type_name_to_vtype_resolved(&elem_ty, self.class_file); + let temp = + self.locals + .allocate_with_vtype("__arr_store_tmp", &elem_ty, arr_tmp_vtype); + self.emit_store(&elem_ty, temp); + self.gen_expr(array)?; + self.gen_expr(index)?; + self.emit_load(&elem_ty, temp); + self.emit_array_store(&array_ty); + Ok(()) + } + _ => Err(CompileError::CodegenError { + message: "invalid assignment target".into(), + }), + } + } + + // --- Helper: does this expression leave a value on the stack? --- + + fn expr_leaves_value(&self, expr: &CExpr) -> bool { + match expr { + CExpr::Assign { .. } => { + // All assignments leave the assigned value on the stack (arrays use dup_x2). + true + } + CExpr::CompoundAssign { .. 
} => true, + CExpr::PostIncrement(_) | CExpr::PostDecrement(_) => true, + CExpr::PreIncrement(_) | CExpr::PreDecrement(_) => true, + CExpr::MethodCall { + object, name, args, .. + } => { + // Check if the method returns void by looking up in pool + // For simplicity, assume non-void unless we can determine otherwise + // Well-known void methods: + if name == "println" || name == "print" || name == "close" || name == "flush" { + // These are commonly void but we still need to check + // For now, check if the call is on a PrintStream-like object + if let Some(obj) = object + && self.is_print_stream_chain(obj) + { + return false; + } + } + // Try to find method descriptor to check return type + if let Ok(desc) = self.find_method_descriptor_in_pool(name, args) + && desc.ends_with(")V") + { + return false; + } + true + } + CExpr::StaticMethodCall { name, args, .. } => { + if let Ok(desc) = self.find_method_descriptor_in_pool(name, args) + && desc.ends_with(")V") + { + return false; + } + true + } + _ => true, + } + } + + fn is_print_stream_chain(&self, expr: &CExpr) -> bool { + // Detect System.out pattern + if let CExpr::FieldAccess { object, name } = expr + && (name == "out" || name == "err") + && let CExpr::Ident(class_name) = object.as_ref() + && class_name == "System" + { + return true; + } + false + } + + // --- Instruction emission helpers --- + + fn emit_int_const(&mut self, value: i64) { + match value { + -1 => { + self.emit(Instruction::Iconstm1); + } + 0 => { + self.emit(Instruction::Iconst0); + } + 1 => { + self.emit(Instruction::Iconst1); + } + 2 => { + self.emit(Instruction::Iconst2); + } + 3 => { + self.emit(Instruction::Iconst3); + } + 4 => { + self.emit(Instruction::Iconst4); + } + 5 => { + self.emit(Instruction::Iconst5); + } + v if (-128..=127).contains(&v) => { + self.emit(Instruction::Bipush(v as i8)); + } + v if (-32768..=32767).contains(&v) => { + self.emit(Instruction::Sipush(v as i16)); + } + v => { + let cp_idx = 
self.class_file.get_or_add_integer(v as i32); + self.emit_ldc(cp_idx); + } + } + } + + fn emit_long_const(&mut self, value: i64) { + match value { + 0 => { + self.emit(Instruction::Lconst0); + } + 1 => { + self.emit(Instruction::Lconst1); + } + _ => { + let cp_idx = self.class_file.get_or_add_long(value); + self.emit(Instruction::Ldc2W(cp_idx)); + } + } + } + + fn emit_float_const(&mut self, value: f32) { + if value == 0.0 && value.is_sign_positive() { + self.emit(Instruction::Fconst0); + } else if value == 1.0 { + self.emit(Instruction::Fconst1); + } else if value == 2.0 { + self.emit(Instruction::Fconst2); + } else { + let cp_idx = self.class_file.get_or_add_float(value); + self.emit_ldc(cp_idx); + } + } + + fn emit_double_const(&mut self, value: f64) { + if value == 0.0 && value.is_sign_positive() { + self.emit(Instruction::Dconst0); + } else if value == 1.0 { + self.emit(Instruction::Dconst1); + } else { + let cp_idx = self.class_file.get_or_add_double(value); + self.emit(Instruction::Ldc2W(cp_idx)); + } + } + + fn emit_ldc(&mut self, cp_idx: u16) { + if cp_idx <= 255 { + self.emit(Instruction::Ldc(cp_idx as u8)); + } else { + self.emit(Instruction::LdcW(cp_idx)); + } + } + + fn emit_load(&mut self, ty: &TypeName, slot: u16) { + if is_reference_type(ty) { + match slot { + 0 => self.emit(Instruction::Aload0), + 1 => self.emit(Instruction::Aload1), + 2 => self.emit(Instruction::Aload2), + 3 => self.emit(Instruction::Aload3), + s if s <= 255 => self.emit(Instruction::Aload(s as u8)), + s => self.emit(Instruction::AloadWide(s)), + }; + } else if is_long_type(ty) { + match slot { + 0 => self.emit(Instruction::Lload0), + 1 => self.emit(Instruction::Lload1), + 2 => self.emit(Instruction::Lload2), + 3 => self.emit(Instruction::Lload3), + s if s <= 255 => self.emit(Instruction::Lload(s as u8)), + s => self.emit(Instruction::LloadWide(s)), + }; + } else if is_float_type(ty) { + match slot { + 0 => self.emit(Instruction::Fload0), + 1 => self.emit(Instruction::Fload1), + 2 
=> self.emit(Instruction::Fload2), + 3 => self.emit(Instruction::Fload3), + s if s <= 255 => self.emit(Instruction::Fload(s as u8)), + s => self.emit(Instruction::FloadWide(s)), + }; + } else if is_double_type(ty) { + match slot { + 0 => self.emit(Instruction::Dload0), + 1 => self.emit(Instruction::Dload1), + 2 => self.emit(Instruction::Dload2), + 3 => self.emit(Instruction::Dload3), + s if s <= 255 => self.emit(Instruction::Dload(s as u8)), + s => self.emit(Instruction::DloadWide(s)), + }; + } else { + // int and friends + match slot { + 0 => self.emit(Instruction::Iload0), + 1 => self.emit(Instruction::Iload1), + 2 => self.emit(Instruction::Iload2), + 3 => self.emit(Instruction::Iload3), + s if s <= 255 => self.emit(Instruction::Iload(s as u8)), + s => self.emit(Instruction::IloadWide(s)), + }; + } + } + + fn emit_store(&mut self, ty: &TypeName, slot: u16) { + if is_reference_type(ty) { + match slot { + 0 => self.emit(Instruction::Astore0), + 1 => self.emit(Instruction::Astore1), + 2 => self.emit(Instruction::Astore2), + 3 => self.emit(Instruction::Astore3), + s if s <= 255 => self.emit(Instruction::Astore(s as u8)), + s => self.emit(Instruction::AstoreWide(s)), + }; + } else if is_long_type(ty) { + match slot { + 0 => self.emit(Instruction::Lstore0), + 1 => self.emit(Instruction::Lstore1), + 2 => self.emit(Instruction::Lstore2), + 3 => self.emit(Instruction::Lstore3), + s if s <= 255 => self.emit(Instruction::Lstore(s as u8)), + s => self.emit(Instruction::LstoreWide(s)), + }; + } else if is_float_type(ty) { + match slot { + 0 => self.emit(Instruction::Fstore0), + 1 => self.emit(Instruction::Fstore1), + 2 => self.emit(Instruction::Fstore2), + 3 => self.emit(Instruction::Fstore3), + s if s <= 255 => self.emit(Instruction::Fstore(s as u8)), + s => self.emit(Instruction::FstoreWide(s)), + }; + } else if is_double_type(ty) { + match slot { + 0 => self.emit(Instruction::Dstore0), + 1 => self.emit(Instruction::Dstore1), + 2 => self.emit(Instruction::Dstore2), + 3 => 
self.emit(Instruction::Dstore3), + s if s <= 255 => self.emit(Instruction::Dstore(s as u8)), + s => self.emit(Instruction::DstoreWide(s)), + }; + } else { + match slot { + 0 => self.emit(Instruction::Istore0), + 1 => self.emit(Instruction::Istore1), + 2 => self.emit(Instruction::Istore2), + 3 => self.emit(Instruction::Istore3), + s if s <= 255 => self.emit(Instruction::Istore(s as u8)), + s => self.emit(Instruction::IstoreWide(s)), + }; + } + } + + // --- Typed instruction emission --- + + fn emit_typed_binary_op(&mut self, op: &BinOp, ty: &TypeName) -> Result<(), CompileError> { + if is_long_type(ty) { + let instr = match op { + BinOp::Add => Instruction::Ladd, + BinOp::Sub => Instruction::Lsub, + BinOp::Mul => Instruction::Lmul, + BinOp::Div => Instruction::Ldiv, + BinOp::Rem => Instruction::Lrem, + BinOp::Shl => Instruction::Lshl, + BinOp::Shr => Instruction::Lshr, + BinOp::Ushr => Instruction::Lushr, + BinOp::BitAnd => Instruction::Land, + BinOp::BitOr => Instruction::Lor, + BinOp::BitXor => Instruction::Lxor, + }; + self.emit(instr); + } else if is_float_type(ty) { + let instr = match op { + BinOp::Add => Instruction::Fadd, + BinOp::Sub => Instruction::Fsub, + BinOp::Mul => Instruction::Fmul, + BinOp::Div => Instruction::Fdiv, + BinOp::Rem => Instruction::Frem, + _ => { + return Err(CompileError::CodegenError { + message: format!("bitwise/shift operator {:?} is not valid on float", op), + }); + } + }; + self.emit(instr); + } else if is_double_type(ty) { + let instr = match op { + BinOp::Add => Instruction::Dadd, + BinOp::Sub => Instruction::Dsub, + BinOp::Mul => Instruction::Dmul, + BinOp::Div => Instruction::Ddiv, + BinOp::Rem => Instruction::Drem, + _ => { + return Err(CompileError::CodegenError { + message: format!("bitwise/shift operator {:?} is not valid on double", op), + }); + } + }; + self.emit(instr); + } else { + // int and sub-int types + let instr = match op { + BinOp::Add => Instruction::Iadd, + BinOp::Sub => Instruction::Isub, + BinOp::Mul => 
Instruction::Imul, + BinOp::Div => Instruction::Idiv, + BinOp::Rem => Instruction::Irem, + BinOp::Shl => Instruction::Ishl, + BinOp::Shr => Instruction::Ishr, + BinOp::Ushr => Instruction::Iushr, + BinOp::BitAnd => Instruction::Iand, + BinOp::BitOr => Instruction::Ior, + BinOp::BitXor => Instruction::Ixor, + }; + self.emit(instr); + } + Ok(()) + } + + /// Emit a typed comparison instruction for non-int types. + /// For long: lcmp; for float: fcmpl/fcmpg; for double: dcmpl/dcmpg. + /// The result is an int that can be used with ifeq/ifne/iflt/ifge/ifgt/ifle. + fn emit_typed_compare(&mut self, ty: &TypeName, op: &CompareOp) { + if is_long_type(ty) { + self.emit(Instruction::Lcmp); + } else if is_float_type(ty) { + // fcmpl for > and >= (NaN → -1, so the comparison evaluates to false) + // fcmpg for everything else (NaN → 1, so <, <= and == are false and != is true) + match op { + CompareOp::Gt | CompareOp::Ge => self.emit(Instruction::Fcmpl), + _ => self.emit(Instruction::Fcmpg), + }; + } else if is_double_type(ty) { + // Same NaN handling as float: dcmpl for > and >=, dcmpg otherwise + match op { + CompareOp::Gt | CompareOp::Ge => self.emit(Instruction::Dcmpl), + _ => self.emit(Instruction::Dcmpg), + }; + } + } + + // --- String concatenation --- + + /// Check if a BinaryOp::Add expression involves string concatenation. + fn is_string_concat(&self, expr: &CExpr) -> bool { + if let CExpr::BinaryOp { + op: BinOp::Add, + left, + right, + } = expr + { + let lt = self.infer_expr_type(left); + let rt = self.infer_expr_type(right); + is_string_type(&lt) + || is_string_type(&rt) + || self.is_string_concat(left) + || self.is_string_concat(right) + } else { + false + } + } + + /// Flatten a chain of BinaryOp::Add nodes into a list of parts for StringBuilder.
+ fn flatten_string_concat<'b>(&self, expr: &'b CExpr, parts: &mut Vec<&'b CExpr>) { + if let CExpr::BinaryOp { + op: BinOp::Add, + left, + right, + } = expr + && self.is_string_concat(expr) + { + self.flatten_string_concat(left, parts); + self.flatten_string_concat(right, parts); + return; + } + parts.push(expr); + } + + /// Generate StringBuilder-based string concatenation. + fn gen_string_concat(&mut self, expr: &CExpr) -> Result<(), CompileError> { + let mut parts = Vec::new(); + self.flatten_string_concat(expr, &mut parts); + + // new StringBuilder() + let sb_class = self.class_file.get_or_add_class("java/lang/StringBuilder"); + self.emit(Instruction::New(sb_class)); + self.emit(Instruction::Dup); + let init_idx = + self.class_file + .get_or_add_method_ref("java/lang/StringBuilder", "<init>", "()V"); + self.emit(Instruction::Invokespecial(init_idx)); + + // .append(part) for each part + for part in &parts { + let desc = self.infer_append_descriptor(part); + self.gen_expr(part)?; + let append_idx = + self.class_file + .get_or_add_method_ref("java/lang/StringBuilder", "append", &desc); + self.emit(Instruction::Invokevirtual(append_idx)); + } + + // .toString() + let tostring_idx = self.class_file.get_or_add_method_ref( + "java/lang/StringBuilder", + "toString", + "()Ljava/lang/String;", + ); + self.emit(Instruction::Invokevirtual(tostring_idx)); + Ok(()) + } + + // --- Type inference --- + + /// Infer the type that an expression produces on the JVM stack.
+ fn infer_expr_type(&self, expr: &CExpr) -> TypeName { + match expr { + CExpr::IntLiteral(_) => TypeName::Primitive(PrimitiveKind::Int), + CExpr::LongLiteral(_) => TypeName::Primitive(PrimitiveKind::Long), + CExpr::FloatLiteral(_) => TypeName::Primitive(PrimitiveKind::Float), + CExpr::DoubleLiteral(_) => TypeName::Primitive(PrimitiveKind::Double), + CExpr::BoolLiteral(_) => TypeName::Primitive(PrimitiveKind::Boolean), + CExpr::CharLiteral(_) => TypeName::Primitive(PrimitiveKind::Char), + CExpr::StringLiteral(_) => TypeName::Class("String".into()), + CExpr::NullLiteral => TypeName::Class("Object".into()), + CExpr::This => TypeName::Class("this".into()), + CExpr::Ident(name) => { + if let Some((_, ty)) = self.locals.find(name) { + ty.clone() + } else { + TypeName::Primitive(PrimitiveKind::Int) + } + } + CExpr::BinaryOp { op, left, right } => { + let lt = self.infer_expr_type(left); + let rt = self.infer_expr_type(right); + if *op == BinOp::Add && (is_string_type(&lt) || is_string_type(&rt)) { + TypeName::Class("String".into()) + } else { + promote_numeric_type(&lt, &rt) + } + } + CExpr::UnaryOp { operand, .. } => self.infer_expr_type(operand), + CExpr::Comparison { .. } + | CExpr::LogicalAnd(_, _) + | CExpr::LogicalOr(_, _) + | CExpr::LogicalNot(_) + | CExpr::Instanceof { .. } => TypeName::Primitive(PrimitiveKind::Boolean), + CExpr::Cast { ty, .. } => ty.clone(), + CExpr::Assign { value, .. } => self.infer_expr_type(value), + CExpr::CompoundAssign { target, .. } => self.infer_expr_type(target), + CExpr::PreIncrement(e) + | CExpr::PreDecrement(e) + | CExpr::PostIncrement(e) + | CExpr::PostDecrement(e) => self.infer_expr_type(e), + CExpr::NewObject { class_name, .. } => TypeName::Class(class_name.clone()), + CExpr::NewArray { element_type, ..
} => TypeName::Array(Box::new(element_type.clone())), + CExpr::NewMultiArray { + element_type, + dimensions, + } => { + let mut ty = element_type.clone(); + for _ in 0..dimensions.len() { + ty = TypeName::Array(Box::new(ty)); + } + ty + } + CExpr::SwitchExpr { + cases, + default_expr, + .. + } => { + if let Some(first_case) = cases.first() { + self.infer_expr_type(&first_case.expr) + } else { + self.infer_expr_type(default_expr) + } + } + CExpr::Lambda { .. } | CExpr::MethodRef { .. } => TypeName::Class("Object".into()), + CExpr::ArrayAccess { array, .. } => { + match self.infer_expr_type(array) { + TypeName::Array(inner) => *inner, + _ => TypeName::Primitive(PrimitiveKind::Int), // fallback + } + } + CExpr::Ternary { then_expr, .. } => self.infer_expr_type(then_expr), + CExpr::MethodCall { name, args, .. } | CExpr::StaticMethodCall { name, args, .. } => { + // Try to infer return type from method descriptor + if let Ok(desc) = self.find_method_descriptor_in_pool(name, args) + && let Some(ret_start) = desc.rfind(')') + { + let ret_desc = &desc[ret_start + 1..]; + return descriptor_to_type(ret_desc); + } + TypeName::Class("Object".into()) + } + CExpr::FieldAccess { object, name } => { + // array.length → int + if name == "length" { + let obj_ty = self.infer_expr_type(object); + if matches!(obj_ty, TypeName::Array(_)) { + return TypeName::Primitive(PrimitiveKind::Int); + } + } + TypeName::Class("Object".into()) + } + CExpr::StaticFieldAccess { .. } => TypeName::Class("Object".into()), + } + } + + /// Emit a widening conversion if `from` is narrower than `to`. 
+ fn emit_widen_if_needed(&mut self, from: &TypeName, to: &TypeName) { + if numeric_rank(from) >= numeric_rank(to) { + return; + } + match (numeric_rank(from), numeric_rank(to)) { + (0, 1) => { + self.emit(Instruction::I2l); + } // int → long + (0, 2) => { + self.emit(Instruction::I2f); + } // int → float + (0, 3) => { + self.emit(Instruction::I2d); + } // int → double + (1, 2) => { + self.emit(Instruction::L2f); + } // long → float + (1, 3) => { + self.emit(Instruction::L2d); + } // long → double + (2, 3) => { + self.emit(Instruction::F2d); + } // float → double + _ => {} + } + } + + /// Emit a type-appropriate array load instruction based on the array type. + fn emit_array_load(&mut self, array_ty: &TypeName) { + match array_ty { + TypeName::Array(inner) => match inner.as_ref() { + TypeName::Primitive(PrimitiveKind::Int) => { + self.emit(Instruction::Iaload); + } + TypeName::Primitive(PrimitiveKind::Long) => { + self.emit(Instruction::Laload); + } + TypeName::Primitive(PrimitiveKind::Float) => { + self.emit(Instruction::Faload); + } + TypeName::Primitive(PrimitiveKind::Double) => { + self.emit(Instruction::Daload); + } + TypeName::Primitive(PrimitiveKind::Byte | PrimitiveKind::Boolean) => { + self.emit(Instruction::Baload); + } + TypeName::Primitive(PrimitiveKind::Char) => { + self.emit(Instruction::Caload); + } + TypeName::Primitive(PrimitiveKind::Short) => { + self.emit(Instruction::Saload); + } + _ => { + self.emit(Instruction::Aaload); + } + }, + _ => { + self.emit(Instruction::Aaload); + } + } + } + + /// Emit a type-appropriate array store instruction based on the array type. 
+ fn emit_array_store(&mut self, array_ty: &TypeName) { + match array_ty { + TypeName::Array(inner) => match inner.as_ref() { + TypeName::Primitive(PrimitiveKind::Int) => { + self.emit(Instruction::Iastore); + } + TypeName::Primitive(PrimitiveKind::Long) => { + self.emit(Instruction::Lastore); + } + TypeName::Primitive(PrimitiveKind::Float) => { + self.emit(Instruction::Fastore); + } + TypeName::Primitive(PrimitiveKind::Double) => { + self.emit(Instruction::Dastore); + } + TypeName::Primitive(PrimitiveKind::Byte | PrimitiveKind::Boolean) => { + self.emit(Instruction::Bastore); + } + TypeName::Primitive(PrimitiveKind::Char) => { + self.emit(Instruction::Castore); + } + TypeName::Primitive(PrimitiveKind::Short) => { + self.emit(Instruction::Sastore); + } + _ => { + self.emit(Instruction::Aastore); + } + }, + _ => { + self.emit(Instruction::Aastore); + } + } + } + + /// Emit the constant `1` in the appropriate type. + fn emit_typed_const_one(&mut self, ty: &TypeName) { + if is_long_type(ty) { + self.emit(Instruction::Lconst1); + } else if is_float_type(ty) { + self.emit(Instruction::Fconst1); + } else if is_double_type(ty) { + self.emit(Instruction::Dconst1); + } else { + self.emit(Instruction::Iconst1); + } + } + + /// Emit a narrowing conversion from a wider type back to a target type. 
+ fn emit_narrow(&mut self, from: &TypeName, to: &TypeName) { + let from_rank = numeric_rank(from); + let to_rank = numeric_rank(to); + if from_rank <= to_rank { + return; + } + match (from_rank, to_rank) { + (3, 2) => { + self.emit(Instruction::D2f); + } + (3, 1) => { + self.emit(Instruction::D2l); + } + (3, 0) => { + self.emit(Instruction::D2i); + } + (2, 1) => { + self.emit(Instruction::F2l); + } + (2, 0) => { + self.emit(Instruction::F2i); + } + (1, 0) => { + self.emit(Instruction::L2i); + } + _ => {} + } + } + + // --- Type/descriptor resolution helpers --- + + fn get_this_class_name(&self) -> Result<String, CompileError> { + use crate::constant_info::ConstantInfo; + let this_class = self.class_file.this_class; + match &self.class_file.const_pool[(this_class - 1) as usize] { + ConstantInfo::Class(c) => self + .class_file + .get_utf8(c.name_index) + .map(|s| s.to_string()) + .ok_or_else(|| CompileError::CodegenError { + message: "could not resolve this class name".into(), + }), + _ => Err(CompileError::CodegenError { + message: "this_class does not point to a Class constant".into(), + }), + } + } + + fn infer_receiver_class(&self, expr: &CExpr) -> Result<String, CompileError> { + match expr { + CExpr::This => self.get_this_class_name(), + CExpr::Ident(name) => { + // Check local variable type + if let Some((_, ty)) = self.locals.find(name) + && let TypeName::Class(class_name) = ty + { + return Ok(resolve_class_name(class_name)); + } + // Might be a class name for static access + if name.chars().next().is_some_and(|c| c.is_ascii_uppercase()) { + return Ok(resolve_class_name(name)); + } + Err(CompileError::CodegenError { + message: format!("cannot infer class for receiver '{}'", name), + }) + } + CExpr::FieldAccess { + object, + name: field_name, + } => { + // Try to find the field's type in the constant pool + if let Ok(class_name) = self.infer_receiver_class(object) + && let Ok(desc) = self.find_field_descriptor_in_pool(&class_name, field_name) + && let Ok(internal) = descriptor_to_internal(&desc) + {
return Ok(internal); + } + Err(CompileError::CodegenError { + message: format!("cannot infer class for field access '{}'", field_name), + }) + } + CExpr::MethodCall { + object, name, args, .. + } => { + // Try to infer from method return type in pool + if let Some(obj) = object + && let Ok(_owner_class) = self.infer_receiver_class(obj) + && let Ok(desc) = self.find_method_descriptor_in_pool(name, args) + { + // Extract return type from descriptor (everything after ')') + if let Some(ret_desc) = desc.rsplit(')').next() + && let Ok(internal) = descriptor_to_internal(ret_desc) + { + return Ok(internal); + } + } + Err(CompileError::CodegenError { + message: format!("cannot infer class for method call '{}'", name), + }) + } + CExpr::StringLiteral(_) => Ok("java/lang/String".into()), + CExpr::NewObject { class_name, .. } => Ok(resolve_class_name(class_name)), + _ => Err(CompileError::CodegenError { + message: "cannot infer receiver class".into(), + }), + } + } + + /// Resolve method descriptor. For well-known overloaded methods, infers from + /// argument types. For others, searches the constant pool. 
+ fn find_method_descriptor_in_pool( + &self, + method_name: &str, + args: &[CExpr], + ) -> Result<String, CompileError> { + // Well-known overloaded methods — infer from arg types first to avoid + // picking the wrong overload from the pool + match method_name { + "println" => { + return match args.len() { + 0 => Ok("()V".into()), + 1 => Ok(self.infer_println_descriptor(&args[0])), + _ => Ok("(Ljava/lang/String;)V".into()), + }; + } + "print" => { + return match args.len() { + 1 => Ok(self.infer_println_descriptor(&args[0])), + _ => Ok("(Ljava/lang/String;)V".into()), + }; + } + "append" if args.len() == 1 => return Ok(self.infer_append_descriptor(&args[0])), + "toString" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "equals" if args.len() == 1 => return Ok("(Ljava/lang/Object;)Z".into()), + "hashCode" if args.is_empty() => return Ok("()I".into()), + "length" if args.is_empty() => return Ok("()I".into()), + "charAt" if args.len() == 1 => return Ok("(I)C".into()), + "substring" if args.len() == 1 => return Ok("(I)Ljava/lang/String;".into()), + "substring" if args.len() == 2 => return Ok("(II)Ljava/lang/String;".into()), + "valueOf" if args.len() == 1 => { + return Ok("(Ljava/lang/Object;)Ljava/lang/String;".into()); + } + "parseInt" if args.len() == 1 => return Ok("(Ljava/lang/String;)I".into()), + "getClass" if args.is_empty() => return Ok("()Ljava/lang/Class;".into()), + "getName" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "getSimpleName" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "getCanonicalName" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "getMessage" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "getCause" if args.is_empty() => return Ok("()Ljava/lang/Throwable;".into()), + "getStackTrace" if args.is_empty() => { + return Ok("()[Ljava/lang/StackTraceElement;".into()); + } + "trim" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "toUpperCase" if
args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "toLowerCase" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "intern" if args.is_empty() => return Ok("()Ljava/lang/String;".into()), + "isEmpty" if args.is_empty() => return Ok("()Z".into()), + "startsWith" if args.len() == 1 => return Ok("(Ljava/lang/String;)Z".into()), + "endsWith" if args.len() == 1 => return Ok("(Ljava/lang/String;)Z".into()), + "contains" if args.len() == 1 => return Ok("(Ljava/lang/CharSequence;)Z".into()), + "replace" if args.len() == 2 => { + return Ok( + "(Ljava/lang/CharSequence;Ljava/lang/CharSequence;)Ljava/lang/String;".into(), + ); + } + "split" if args.len() == 1 => { + return Ok("(Ljava/lang/String;)[Ljava/lang/String;".into()); + } + "toCharArray" if args.is_empty() => return Ok("()[C".into()), + "format" if !args.is_empty() => { + return Ok(format!("(Ljava/lang/String;{}V", "[Ljava/lang/Object;)")); + } + _ => {} + } + + // Search constant pool for matching method + use crate::constant_info::ConstantInfo; + let pool = &self.class_file.const_pool; + + for entry in pool.iter() { + let nat_index = match entry { + ConstantInfo::MethodRef(r) => r.name_and_type_index, + ConstantInfo::InterfaceMethodRef(r) => r.name_and_type_index, + _ => continue, + }; + if let ConstantInfo::NameAndType(nat) = &pool[(nat_index - 1) as usize] + && let Some(name) = self.class_file.get_utf8(nat.name_index) + && name == method_name + && let Some(desc) = self.class_file.get_utf8(nat.descriptor_index) + && let Some((params, _)) = parse_method_descriptor(desc) + && params.len() == args.len() + { + return Ok(desc.to_string()); + } + } + + // Not found in pool — infer descriptor from argument types + Ok(self.infer_method_descriptor(method_name, args)) + } + + fn infer_println_descriptor(&self, arg: &CExpr) -> String { + match arg { + CExpr::StringLiteral(_) => "(Ljava/lang/String;)V".into(), + CExpr::IntLiteral(_) => "(I)V".into(), + CExpr::LongLiteral(_) => "(J)V".into(), + 
CExpr::FloatLiteral(_) => "(F)V".into(), + CExpr::DoubleLiteral(_) => "(D)V".into(), + CExpr::BoolLiteral(_) => "(Z)V".into(), + CExpr::CharLiteral(_) => "(C)V".into(), + CExpr::Ident(name) => { + if let Some((_, ty)) = self.locals.find(name) { + match ty { + TypeName::Primitive(PrimitiveKind::Int) + | TypeName::Primitive(PrimitiveKind::Byte) + | TypeName::Primitive(PrimitiveKind::Short) => return "(I)V".into(), + TypeName::Primitive(PrimitiveKind::Long) => return "(J)V".into(), + TypeName::Primitive(PrimitiveKind::Float) => return "(F)V".into(), + TypeName::Primitive(PrimitiveKind::Double) => return "(D)V".into(), + TypeName::Primitive(PrimitiveKind::Boolean) => return "(Z)V".into(), + TypeName::Primitive(PrimitiveKind::Char) => return "(C)V".into(), + TypeName::Class(name) if name == "String" || name == "java.lang.String" => { + return "(Ljava/lang/String;)V".into(); + } + _ => {} + } + } + "(Ljava/lang/Object;)V".into() + } + // For compound expressions, infer the result type + _ => { + let ty = self.infer_expr_type(arg); + match &ty { + TypeName::Primitive(PrimitiveKind::Int) + | TypeName::Primitive(PrimitiveKind::Byte) + | TypeName::Primitive(PrimitiveKind::Short) => "(I)V".into(), + TypeName::Primitive(PrimitiveKind::Long) => "(J)V".into(), + TypeName::Primitive(PrimitiveKind::Float) => "(F)V".into(), + TypeName::Primitive(PrimitiveKind::Double) => "(D)V".into(), + TypeName::Primitive(PrimitiveKind::Boolean) => "(Z)V".into(), + TypeName::Primitive(PrimitiveKind::Char) => "(C)V".into(), + TypeName::Class(name) + if name == "String" + || name == "java.lang.String" + || name == "java/lang/String" => + { + "(Ljava/lang/String;)V".into() + } + _ => "(Ljava/lang/Object;)V".into(), + } + } + } + } + + fn infer_append_descriptor(&self, arg: &CExpr) -> String { + let ty = self.infer_expr_type(arg); + match &ty { + TypeName::Primitive(PrimitiveKind::Int) + | TypeName::Primitive(PrimitiveKind::Byte) + | TypeName::Primitive(PrimitiveKind::Short) => 
"(I)Ljava/lang/StringBuilder;".into(), + TypeName::Primitive(PrimitiveKind::Long) => "(J)Ljava/lang/StringBuilder;".into(), + TypeName::Primitive(PrimitiveKind::Float) => "(F)Ljava/lang/StringBuilder;".into(), + TypeName::Primitive(PrimitiveKind::Double) => "(D)Ljava/lang/StringBuilder;".into(), + TypeName::Primitive(PrimitiveKind::Boolean) => "(Z)Ljava/lang/StringBuilder;".into(), + TypeName::Primitive(PrimitiveKind::Char) => "(C)Ljava/lang/StringBuilder;".into(), + TypeName::Class(name) + if name == "String" || name == "java.lang.String" || name == "java/lang/String" => + { + "(Ljava/lang/String;)Ljava/lang/StringBuilder;".into() + } + _ => "(Ljava/lang/Object;)Ljava/lang/StringBuilder;".into(), + } + } + + fn infer_constructor_descriptor(&self, args: &[CExpr]) -> Result { + if args.is_empty() { + return Ok("()V".into()); + } + // Build descriptor from arg types + let mut desc = String::from("("); + for arg in args { + desc.push_str(&self.infer_arg_descriptor(arg)); + } + desc.push_str(")V"); + Ok(desc) + } + + fn infer_arg_descriptor(&self, arg: &CExpr) -> String { + match arg { + CExpr::StringLiteral(_) => "Ljava/lang/String;".into(), + CExpr::IntLiteral(_) => "I".into(), + CExpr::LongLiteral(_) => "J".into(), + CExpr::FloatLiteral(_) => "F".into(), + CExpr::DoubleLiteral(_) => "D".into(), + CExpr::BoolLiteral(_) => "Z".into(), + CExpr::CharLiteral(_) => "C".into(), + CExpr::Ident(name) => { + if let Some((_, ty)) = self.locals.find(name) { + return type_name_to_descriptor(ty); + } + "Ljava/lang/Object;".into() + } + _ => "Ljava/lang/Object;".into(), + } + } + + /// Infer a method descriptor from argument types and method name heuristics. + /// Used as fallback when the method isn't found in the constant pool. 
+ fn infer_method_descriptor(&self, method_name: &str, args: &[CExpr]) -> String { + // Well-known collection/generic methods with fixed descriptors (type-erased signatures) + match (method_name, args.len()) { + ("add", 1) => return "(Ljava/lang/Object;)Z".into(), + ("add", 2) => return "(ILjava/lang/Object;)V".into(), + ("get", 1) => return "(I)Ljava/lang/Object;".into(), + ("set", 2) => return "(ILjava/lang/Object;)Ljava/lang/Object;".into(), + ("remove", 1) => return "(Ljava/lang/Object;)Z".into(), + ("contains", 1) => return "(Ljava/lang/Object;)Z".into(), + ("put", 2) => return "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;".into(), + ("containsKey", 1) => return "(Ljava/lang/Object;)Z".into(), + ("containsValue", 1) => return "(Ljava/lang/Object;)Z".into(), + ("size", 0) => return "()I".into(), + ("isEmpty", 0) => return "()Z".into(), + ("clear", 0) => return "()V".into(), + ("iterator", 0) => return "()Ljava/util/Iterator;".into(), + ("hasNext", 0) => return "()Z".into(), + ("next", 0) => return "()Ljava/lang/Object;".into(), + ("offer", 1) => return "(Ljava/lang/Object;)Z".into(), + _ => {} + } + + // Generic fallback: infer from arg types + let mut desc = String::from("("); + for arg in args { + desc.push_str(&self.infer_arg_descriptor(arg)); + } + desc.push(')'); + + // Heuristic return type based on well-known method names + let ret = match method_name { + "size" | "length" | "indexOf" | "lastIndexOf" | "compareTo" | "read" | "intValue" + | "ordinal" | "hashCode" => "I", + "isEmpty" | "contains" | "containsKey" | "containsValue" | "hasNext" + | "hasPrevious" => "Z", + "longValue" => "J", + "floatValue" => "F", + "doubleValue" => "D", + "charAt" => "C", + _ => "Ljava/lang/Object;", + }; + desc.push_str(ret); + desc + } + + /// Check if the constant pool has a FieldRef for the given class and field. 
+ fn has_field_ref_for_class(&self, class_name: &str, field_name: &str) -> bool { + use crate::constant_info::ConstantInfo; + let pool = &self.class_file.const_pool; + for entry in pool.iter() { + if let ConstantInfo::FieldRef(r) = entry + && let ConstantInfo::Class(c) = &pool[(r.class_index - 1) as usize] + && let Some(cn) = self.class_file.get_utf8(c.name_index) + && cn == class_name + && let ConstantInfo::NameAndType(nat) = &pool[(r.name_and_type_index - 1) as usize] + && let Some(name) = self.class_file.get_utf8(nat.name_index) + && name == field_name + { + return true; + } + } + false + } + + /// Check if a class is a known interface or has InterfaceMethodRef entries in the pool. + fn is_interface_method(&self, class_name: &str, method_name: &str) -> bool { + // Well-known JDK interfaces + const KNOWN_INTERFACES: &[&str] = &[ + "java/util/List", + "java/util/Map", + "java/util/Set", + "java/util/Collection", + "java/util/Iterator", + "java/util/Iterable", + "java/util/Enumeration", + "java/util/Comparator", + "java/util/Deque", + "java/util/Queue", + "java/util/SortedMap", + "java/util/SortedSet", + "java/util/NavigableMap", + "java/util/NavigableSet", + "java/util/concurrent/Callable", + "java/util/concurrent/Future", + "java/util/concurrent/BlockingQueue", + "java/util/concurrent/BlockingDeque", + "java/lang/Runnable", + "java/lang/Comparable", + "java/lang/CharSequence", + "java/lang/Appendable", + "java/lang/AutoCloseable", + "java/lang/Closeable", + "java/io/Closeable", + "java/io/Serializable", + "java/io/Flushable", + ]; + if KNOWN_INTERFACES.contains(&class_name) { + return true; + } + + // Check the constant pool for InterfaceMethodRef entries matching this class+method + use crate::constant_info::ConstantInfo; + let pool = &self.class_file.const_pool; + for entry in pool.iter() { + if let ConstantInfo::InterfaceMethodRef(r) = entry + && let ConstantInfo::Class(c) = &pool[(r.class_index - 1) as usize] + && let Some(cn) = 
self.class_file.get_utf8(c.name_index) + && cn == class_name + && let ConstantInfo::NameAndType(nat) = &pool[(r.name_and_type_index - 1) as usize] + && let Some(name) = self.class_file.get_utf8(nat.name_index) + && name == method_name + { + return true; + } + } + false + } + + fn find_field_descriptor_in_pool( + &self, + class_name: &str, + field_name: &str, + ) -> Result<String, CompileError> { + use crate::constant_info::ConstantInfo; + let pool = &self.class_file.const_pool; + + for entry in pool.iter() { + if let ConstantInfo::FieldRef(r) = entry { + // Check class + if let ConstantInfo::Class(c) = &pool[(r.class_index - 1) as usize] + && let Some(cn) = self.class_file.get_utf8(c.name_index) + && cn == class_name + && let ConstantInfo::NameAndType(nat) = + &pool[(r.name_and_type_index - 1) as usize] + && let Some(name) = self.class_file.get_utf8(nat.name_index) + && name == field_name + && let Some(desc) = self.class_file.get_utf8(nat.descriptor_index) + { + return Ok(desc.to_string()); + } + } + } + + // Well-known fields + match (class_name, field_name) { + ("java/lang/System", "out") => Ok("Ljava/io/PrintStream;".into()), + ("java/lang/System", "err") => Ok("Ljava/io/PrintStream;".into()), + ("java/lang/System", "in") => Ok("Ljava/io/InputStream;".into()), + ("java/lang/Boolean", "TRUE") => Ok("Ljava/lang/Boolean;".into()), + ("java/lang/Boolean", "FALSE") => Ok("Ljava/lang/Boolean;".into()), + _ => Err(CompileError::CodegenError { + message: format!( + "cannot find field '{}.{}' in constant pool", + class_name, field_name + ), + }), + } + } +} + +// --- Utility functions --- + +fn resolve_label_addr( + label_id: usize, + labels: &[Option<usize>], + addresses: &[u32], + end_addr: i32, +) -> Result<i32, CompileError> { + let target_instr = labels[label_id].ok_or_else(|| CompileError::CodegenError { + message: format!("unresolved label {}", label_id), + })?; + if target_instr < addresses.len() { + Ok(addresses[target_instr] as i32) + } else { + Ok(end_addr) + } +} + +/// Compute the byte offset at a
specific instruction index. +fn compute_byte_offset_at(instructions: &[Instruction], target_idx: usize) -> u32 { + let mut addr = 0u32; + for (i, instr) in instructions.iter().enumerate() { + if i == target_idx { + return addr; + } + addr += instruction_byte_size(instr, addr); + } + addr // past-end offset +} + +pub(crate) fn compute_byte_addresses(instructions: &[Instruction]) -> Vec<u32> { + let mut addresses = Vec::with_capacity(instructions.len()); + let mut addr = 0u32; + for instr in instructions { + addresses.push(addr); + addr += instruction_byte_size(instr, addr); + } + addresses +} + +fn patch_branch_offset(instr: &Instruction, offset: i16) -> Result<Instruction, CompileError> { + Ok(match instr { + Instruction::Goto(_) => Instruction::Goto(offset), + Instruction::Ifeq(_) => Instruction::Ifeq(offset), + Instruction::Ifne(_) => Instruction::Ifne(offset), + Instruction::Iflt(_) => Instruction::Iflt(offset), + Instruction::Ifge(_) => Instruction::Ifge(offset), + Instruction::Ifgt(_) => Instruction::Ifgt(offset), + Instruction::Ifle(_) => Instruction::Ifle(offset), + Instruction::IfIcmpeq(_) => Instruction::IfIcmpeq(offset), + Instruction::IfIcmpne(_) => Instruction::IfIcmpne(offset), + Instruction::IfIcmplt(_) => Instruction::IfIcmplt(offset), + Instruction::IfIcmpge(_) => Instruction::IfIcmpge(offset), + Instruction::IfIcmpgt(_) => Instruction::IfIcmpgt(offset), + Instruction::IfIcmple(_) => Instruction::IfIcmple(offset), + Instruction::IfAcmpeq(_) => Instruction::IfAcmpeq(offset), + Instruction::IfAcmpne(_) => Instruction::IfAcmpne(offset), + Instruction::Ifnull(_) => Instruction::Ifnull(offset), + Instruction::Ifnonnull(_) => Instruction::Ifnonnull(offset), + _ => { + return Err(CompileError::CodegenError { + message: format!("cannot patch branch offset on {:?}", instr), + }); + } + }) +} + +/// Resolve a simple or dotted class name to JVM internal form. +/// Erase a JvmType to its Object-erased descriptor for SAM type descriptors.
+/// Reference types and arrays become `Ljava/lang/Object;`, primitives stay as-is. +fn erase_to_object(ty: &JvmType) -> String { + match ty { + JvmType::Reference(_) | JvmType::Array(_) | JvmType::Null | JvmType::Unknown => { + "Ljava/lang/Object;".into() + } + other => other.to_descriptor(), + } +} + +pub fn resolve_class_name(name: &str) -> String { + // Well-known short names + match name { + "String" => return "java/lang/String".into(), + "Object" => return "java/lang/Object".into(), + "System" => return "java/lang/System".into(), + "Integer" => return "java/lang/Integer".into(), + "Long" => return "java/lang/Long".into(), + "Float" => return "java/lang/Float".into(), + "Double" => return "java/lang/Double".into(), + "Boolean" => return "java/lang/Boolean".into(), + "Byte" => return "java/lang/Byte".into(), + "Character" => return "java/lang/Character".into(), + "Short" => return "java/lang/Short".into(), + "Math" => return "java/lang/Math".into(), + "StringBuilder" => return "java/lang/StringBuilder".into(), + "StringBuffer" => return "java/lang/StringBuffer".into(), + "PrintStream" => return "java/io/PrintStream".into(), + "InputStream" => return "java/io/InputStream".into(), + "Exception" => return "java/lang/Exception".into(), + "RuntimeException" => return "java/lang/RuntimeException".into(), + "NullPointerException" => return "java/lang/NullPointerException".into(), + "IllegalArgumentException" => return "java/lang/IllegalArgumentException".into(), + "IllegalStateException" => return "java/lang/IllegalStateException".into(), + "UnsupportedOperationException" => return "java/lang/UnsupportedOperationException".into(), + "IndexOutOfBoundsException" => return "java/lang/IndexOutOfBoundsException".into(), + "ArrayList" => return "java/util/ArrayList".into(), + "HashMap" => return "java/util/HashMap".into(), + "List" => return "java/util/List".into(), + "Map" => return "java/util/Map".into(), + "Set" => return "java/util/Set".into(), + "Arrays" => return 
"java/util/Arrays".into(), + "Collections" => return "java/util/Collections".into(), + "Class" => return "java/lang/Class".into(), + _ => {} + } + // Convert dotted name to internal form + name.replace('.', "/") +} + +fn type_name_to_descriptor(ty: &TypeName) -> String { + match ty { + TypeName::Primitive(kind) => match kind { + PrimitiveKind::Int => "I".into(), + PrimitiveKind::Long => "J".into(), + PrimitiveKind::Float => "F".into(), + PrimitiveKind::Double => "D".into(), + PrimitiveKind::Boolean => "Z".into(), + PrimitiveKind::Byte => "B".into(), + PrimitiveKind::Char => "C".into(), + PrimitiveKind::Short => "S".into(), + PrimitiveKind::Void => "V".into(), + }, + TypeName::Class(name) => { + let internal = resolve_class_name(name); + format!("L{};", internal) + } + TypeName::Array(inner) => { + format!("[{}", type_name_to_descriptor(inner)) + } + } +} + +fn descriptor_to_internal(desc: &str) -> Result { + if desc.starts_with('L') && desc.ends_with(';') { + Ok(desc[1..desc.len() - 1].to_string()) + } else { + Err(CompileError::CodegenError { + message: format!( + "cannot convert descriptor '{}' to internal class name", + desc + ), + }) + } +} + +/// Convert a field descriptor string to a TypeName. 
+fn descriptor_to_type(desc: &str) -> TypeName { + match desc { + "I" => TypeName::Primitive(PrimitiveKind::Int), + "J" => TypeName::Primitive(PrimitiveKind::Long), + "F" => TypeName::Primitive(PrimitiveKind::Float), + "D" => TypeName::Primitive(PrimitiveKind::Double), + "Z" => TypeName::Primitive(PrimitiveKind::Boolean), + "B" => TypeName::Primitive(PrimitiveKind::Byte), + "C" => TypeName::Primitive(PrimitiveKind::Char), + "S" => TypeName::Primitive(PrimitiveKind::Short), + "V" => TypeName::Primitive(PrimitiveKind::Void), + _ if desc.starts_with('L') && desc.ends_with(';') => { + TypeName::Class(desc[1..desc.len() - 1].to_string()) + } + _ if desc.starts_with('[') => TypeName::Array(Box::new(descriptor_to_type(&desc[1..]))), + _ => TypeName::Class("java/lang/Object".into()), + } +} + +fn jvm_type_to_type_name(jvm_ty: &JvmType) -> TypeName { + match jvm_ty { + JvmType::Int => TypeName::Primitive(PrimitiveKind::Int), + JvmType::Long => TypeName::Primitive(PrimitiveKind::Long), + JvmType::Float => TypeName::Primitive(PrimitiveKind::Float), + JvmType::Double => TypeName::Primitive(PrimitiveKind::Double), + JvmType::Boolean => TypeName::Primitive(PrimitiveKind::Boolean), + JvmType::Byte => TypeName::Primitive(PrimitiveKind::Byte), + JvmType::Char => TypeName::Primitive(PrimitiveKind::Char), + JvmType::Short => TypeName::Primitive(PrimitiveKind::Short), + JvmType::Void => TypeName::Primitive(PrimitiveKind::Void), + JvmType::Reference(name) => TypeName::Class(name.clone()), + JvmType::Array(inner) => TypeName::Array(Box::new(jvm_type_to_type_name(inner))), + JvmType::Null | JvmType::Unknown => TypeName::Class("java/lang/Object".into()), + } +} + +fn is_int_type(ty: &TypeName) -> bool { + matches!( + ty, + TypeName::Primitive( + PrimitiveKind::Int + | PrimitiveKind::Boolean + | PrimitiveKind::Byte + | PrimitiveKind::Char + | PrimitiveKind::Short + ) + ) +} + +fn is_long_type(ty: &TypeName) -> bool { + matches!(ty, TypeName::Primitive(PrimitiveKind::Long)) +} + +fn 
is_float_type(ty: &TypeName) -> bool { + matches!(ty, TypeName::Primitive(PrimitiveKind::Float)) +} + +fn is_double_type(ty: &TypeName) -> bool { + matches!(ty, TypeName::Primitive(PrimitiveKind::Double)) +} + +fn is_reference_type(ty: &TypeName) -> bool { + matches!(ty, TypeName::Class(_) | TypeName::Array(_)) +} + +fn type_slot_width(ty: &TypeName) -> u16 { + match ty { + TypeName::Primitive(PrimitiveKind::Long | PrimitiveKind::Double) => 2, + _ => 1, + } +} + +fn type_name_to_vtype(ty: &TypeName) -> VType { + match ty { + TypeName::Primitive(kind) => match kind { + PrimitiveKind::Int + | PrimitiveKind::Boolean + | PrimitiveKind::Byte + | PrimitiveKind::Char + | PrimitiveKind::Short => VType::Integer, + PrimitiveKind::Long => VType::Long, + PrimitiveKind::Float => VType::Float, + PrimitiveKind::Double => VType::Double, + PrimitiveKind::Void => VType::Top, + }, + TypeName::Class(_) | TypeName::Array(_) => { + // We'd need a class_file reference to resolve the cp index. + // Use a sentinel — this will be resolved later or use Null as fallback. + VType::Null + } + } +} + +/// Resolve a TypeName to a VType, using the class file to get or create constant pool entries. 
+fn type_name_to_vtype_resolved(ty: &TypeName, class_file: &mut ClassFile) -> VType { + match ty { + TypeName::Primitive(kind) => match kind { + PrimitiveKind::Int + | PrimitiveKind::Boolean + | PrimitiveKind::Byte + | PrimitiveKind::Char + | PrimitiveKind::Short => VType::Integer, + PrimitiveKind::Long => VType::Long, + PrimitiveKind::Float => VType::Float, + PrimitiveKind::Double => VType::Double, + PrimitiveKind::Void => VType::Top, + }, + TypeName::Class(name) => { + let internal = resolve_class_name(name); + let idx = class_file.get_or_add_class(&internal); + VType::Object(idx) + } + TypeName::Array(_) => { + let desc = type_name_to_descriptor(ty); + let idx = class_file.get_or_add_class(&desc); + VType::Object(idx) + } + } +} + +/// Resolve a JvmType to a VType, using the class file to get or create constant pool entries. +fn jvm_type_to_vtype_resolved(jvm_ty: &JvmType, class_file: &mut ClassFile) -> VType { + match jvm_ty { + JvmType::Int | JvmType::Boolean | JvmType::Byte | JvmType::Char | JvmType::Short => { + VType::Integer + } + JvmType::Long => VType::Long, + JvmType::Float => VType::Float, + JvmType::Double => VType::Double, + JvmType::Void => VType::Top, + JvmType::Reference(name) => { + let idx = class_file.get_or_add_class(name); + VType::Object(idx) + } + JvmType::Array(_) => { + let ty = jvm_type_to_type_name(jvm_ty); + let desc = type_name_to_descriptor(&ty); + let idx = class_file.get_or_add_class(&desc); + VType::Object(idx) + } + JvmType::Null | JvmType::Unknown => VType::Null, + } +} + +fn is_var_sentinel(ty: &TypeName) -> bool { + matches!(ty, TypeName::Class(name) if name == "__var__") +} + +fn is_string_type(ty: &TypeName) -> bool { + matches!(ty, TypeName::Class(name) if name == "String" || name == "java.lang.String" || name == "java/lang/String") +} + +/// Java numeric widening rank: Int(0) < Long(1) < Float(2) < Double(3). 
+fn numeric_rank(ty: &TypeName) -> u8 { + match ty { + TypeName::Primitive(PrimitiveKind::Double) => 3, + TypeName::Primitive(PrimitiveKind::Float) => 2, + TypeName::Primitive(PrimitiveKind::Long) => 1, + _ => 0, // int and sub-int types + } +} + +/// Compute the `count` argument for invokeinterface: 1 (objectref) + sum of arg slot widths. +fn compute_invokeinterface_count(descriptor: &str) -> u8 { + let (params, _) = parse_method_descriptor(descriptor).unwrap_or((Vec::new(), JvmType::Void)); + let mut count: u8 = 1; // objectref + for param in &params { + count += if param.is_wide() { 2 } else { 1 }; + } + count +} + +/// Return the wider of two numeric types per Java widening rules. +fn promote_numeric_type(a: &TypeName, b: &TypeName) -> TypeName { + if numeric_rank(a) >= numeric_rank(b) { + a.clone() + } else { + b.clone() + } +} diff --git a/src/compile/lexer.rs b/src/compile/lexer.rs new file mode 100644 index 0000000..a36c8dd --- /dev/null +++ b/src/compile/lexer.rs @@ -0,0 +1,780 @@ +use super::CompileError; + +#[derive(Clone, Debug, PartialEq)] +pub enum Token { + // Literals + IntLiteral(i64), + LongLiteral(i64), + FloatLiteral(f64), + DoubleLiteral(f64), + StringLiteral(String), + CharLiteral(char), + + // Identifiers and keywords + Ident(String), + If, + Else, + While, + For, + Return, + New, + This, + Throw, + Break, + Continue, + Instanceof, + Switch, + Case, + Default, + Try, + Catch, + Finally, + Null, + True, + False, + Synchronized, + Var, + + // Primitive type keywords + KwInt, + KwLong, + KwFloat, + KwDouble, + KwBoolean, + KwByte, + KwChar, + KwShort, + KwVoid, + + // Operators + Plus, + Minus, + Star, + Slash, + Percent, + Amp, + Pipe, + Caret, + Tilde, + Bang, + AmpAmp, + PipePipe, + Eq, + EqEq, + BangEq, + Lt, + LtEq, + Gt, + GtEq, + LtLt, + GtGt, + GtGtGt, + PlusEq, + MinusEq, + StarEq, + SlashEq, + PercentEq, + AmpEq, + PipeEq, + CaretEq, + LtLtEq, + GtGtEq, + GtGtGtEq, + PlusPlus, + MinusMinus, + + Arrow, + ColonColon, + + // Delimiters +
LParen, + RParen, + LBrace, + RBrace, + LBracket, + RBracket, + Semicolon, + Comma, + Dot, + Question, + Colon, + + // End of input + Eof, +} + +#[derive(Clone, Debug)] +pub struct SpannedToken { + pub token: Token, + pub line: usize, + pub column: usize, +} + +pub struct Lexer { + chars: Vec<char>, + pos: usize, + line: usize, + column: usize, +} + +impl Lexer { + pub fn new(source: &str) -> Self { + Lexer { + chars: source.chars().collect(), + pos: 0, + line: 1, + column: 1, + } + } + + pub fn tokenize(mut self) -> Result<Vec<SpannedToken>, CompileError> { + let mut tokens = Vec::new(); + loop { + self.skip_whitespace_and_comments(); + if self.pos >= self.chars.len() { + tokens.push(SpannedToken { + token: Token::Eof, + line: self.line, + column: self.column, + }); + break; + } + tokens.push(self.next_token()?); + } + Ok(tokens) + } + + fn peek(&self) -> Option<char> { + self.chars.get(self.pos).copied() + } + + fn peek_ahead(&self, n: usize) -> Option<char> { + self.chars.get(self.pos + n).copied() + } + + fn advance(&mut self) -> Option<char> { + let c = self.chars.get(self.pos).copied()?; + self.pos += 1; + if c == '\n' { + self.line += 1; + self.column = 1; + } else { + self.column += 1; + } + Some(c) + } + + fn skip_whitespace_and_comments(&mut self) { + loop { + // Skip whitespace + while self.peek().is_some_and(|c| c.is_whitespace()) { + self.advance(); + } + // Skip line comments + if self.peek() == Some('/') && self.peek_ahead(1) == Some('/') { + while self.peek().is_some_and(|c| c != '\n') { + self.advance(); + } + continue; + } + // Skip block comments + if self.peek() == Some('/') && self.peek_ahead(1) == Some('*') { + self.advance(); + self.advance(); + loop { + if self.peek().is_none() { + break; + } + if self.peek() == Some('*') && self.peek_ahead(1) == Some('/') { + self.advance(); + self.advance(); + break; + } + self.advance(); + } + continue; + } + break; + } + } + + fn error(&self, message: impl Into<String>) -> CompileError { + CompileError::ParseError { + line: self.line, + column: 
self.column, + message: message.into(), + } + } + + fn next_token(&mut self) -> Result<SpannedToken, CompileError> { + let line = self.line; + let column = self.column; + let c = self.peek().unwrap(); + + let token = match c { + '(' => { + self.advance(); + Token::LParen + } + ')' => { + self.advance(); + Token::RParen + } + '{' => { + self.advance(); + Token::LBrace + } + '}' => { + self.advance(); + Token::RBrace + } + '[' => { + self.advance(); + Token::LBracket + } + ']' => { + self.advance(); + Token::RBracket + } + ';' => { + self.advance(); + Token::Semicolon + } + ',' => { + self.advance(); + Token::Comma + } + '.' => { + self.advance(); + Token::Dot + } + '?' => { + self.advance(); + Token::Question + } + ':' => { + self.advance(); + if self.peek() == Some(':') { + self.advance(); + Token::ColonColon + } else { + Token::Colon + } + } + '~' => { + self.advance(); + Token::Tilde + } + '+' => { + self.advance(); + if self.peek() == Some('+') { + self.advance(); + Token::PlusPlus + } else if self.peek() == Some('=') { + self.advance(); + Token::PlusEq + } else { + Token::Plus + } + } + '-' => { + self.advance(); + if self.peek() == Some('-') { + self.advance(); + Token::MinusMinus + } else if self.peek() == Some('=') { + self.advance(); + Token::MinusEq + } else if self.peek() == Some('>') { + self.advance(); + Token::Arrow + } else { + Token::Minus + } + } + '*' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::StarEq + } else { + Token::Star + } + } + '/' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::SlashEq + } else { + Token::Slash + } + } + '%' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::PercentEq + } else { + Token::Percent + } + } + '&' => { + self.advance(); + if self.peek() == Some('&') { + self.advance(); + Token::AmpAmp + } else if self.peek() == Some('=') { + self.advance(); + Token::AmpEq + } else { + Token::Amp + } + } + '|' => { + self.advance(); + if self.peek() == 
Some('|') { + self.advance(); + Token::PipePipe + } else if self.peek() == Some('=') { + self.advance(); + Token::PipeEq + } else { + Token::Pipe + } + } + '^' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::CaretEq + } else { + Token::Caret + } + } + '!' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::BangEq + } else { + Token::Bang + } + } + '=' => { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::EqEq + } else { + Token::Eq + } + } + '<' => { + self.advance(); + if self.peek() == Some('<') { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::LtLtEq + } else { + Token::LtLt + } + } else if self.peek() == Some('=') { + self.advance(); + Token::LtEq + } else { + Token::Lt + } + } + '>' => { + self.advance(); + if self.peek() == Some('>') { + self.advance(); + if self.peek() == Some('>') { + self.advance(); + if self.peek() == Some('=') { + self.advance(); + Token::GtGtGtEq + } else { + Token::GtGtGt + } + } else if self.peek() == Some('=') { + self.advance(); + Token::GtGtEq + } else { + Token::GtGt + } + } else if self.peek() == Some('=') { + self.advance(); + Token::GtEq + } else { + Token::Gt + } + } + '"' => self.read_string()?, + '\'' => self.read_char()?, + _ if c.is_ascii_digit() => self.read_number()?, + _ if c.is_ascii_alphabetic() || c == '_' => self.read_ident_or_keyword(), + _ => return Err(self.error(format!("unexpected character: '{}'", c))), + }; + + Ok(SpannedToken { + token, + line, + column, + }) + } + + fn read_string(&mut self) -> Result<Token, CompileError> { + self.advance(); // consume opening " + let mut s = String::new(); + loop { + match self.peek() { + None => return Err(self.error("unterminated string literal")), + Some('"') => { + self.advance(); + break; + } + Some('\\') => { + self.advance(); + match self.peek() { + Some('n') => { + self.advance(); + s.push('\n'); + } + Some('t') => { + self.advance(); + s.push('\t'); + } + Some('r') 
=> { + self.advance(); + s.push('\r'); + } + Some('\\') => { + self.advance(); + s.push('\\'); + } + Some('"') => { + self.advance(); + s.push('"'); + } + Some('\'') => { + self.advance(); + s.push('\''); + } + Some('0') => { + self.advance(); + s.push('\0'); + } + _ => return Err(self.error("invalid escape sequence")), + } + } + Some(c) => { + self.advance(); + s.push(c); + } + } + } + Ok(Token::StringLiteral(s)) + } + + fn read_char(&mut self) -> Result<Token, CompileError> { + self.advance(); // consume opening ' + let c = match self.peek() { + None => return Err(self.error("unterminated char literal")), + Some('\\') => { + self.advance(); + match self.peek() { + Some('n') => { + self.advance(); + '\n' + } + Some('t') => { + self.advance(); + '\t' + } + Some('r') => { + self.advance(); + '\r' + } + Some('\\') => { + self.advance(); + '\\' + } + Some('\'') => { + self.advance(); + '\'' + } + Some('0') => { + self.advance(); + '\0' + } + _ => return Err(self.error("invalid escape sequence in char literal")), + } + } + Some(c) => { + self.advance(); + c + } + }; + if self.peek() != Some('\'') { + return Err(self.error("expected closing ' in char literal")); + } + self.advance(); + Ok(Token::CharLiteral(c)) + } + + fn read_number(&mut self) -> Result<Token, CompileError> { + let mut num_str = String::new(); + let mut is_float = false; + let mut is_long = false; + let mut is_explicit_float = false; + let mut is_explicit_double = false; + + // Handle hex + if self.peek() == Some('0') + && (self.peek_ahead(1) == Some('x') || self.peek_ahead(1) == Some('X')) + { + num_str.push(self.advance().unwrap()); // '0' + num_str.push(self.advance().unwrap()); // 'x' + while self + .peek() + .is_some_and(|c| c.is_ascii_hexdigit() || c == '_') + { + let c = self.advance().unwrap(); + if c != '_' { + num_str.push(c); + } + } + if self.peek() == Some('L') || self.peek() == Some('l') { + self.advance(); + let val = i64::from_str_radix(&num_str[2..], 16) + .map_err(|_| self.error("invalid hex long literal"))?; + return 
Ok(Token::LongLiteral(val)); + } + let val = i64::from_str_radix(&num_str[2..], 16) + .map_err(|_| self.error("invalid hex literal"))?; + return Ok(Token::IntLiteral(val)); + } + + // Regular decimal + while self.peek().is_some_and(|c| c.is_ascii_digit() || c == '_') { + let c = self.advance().unwrap(); + if c != '_' { + num_str.push(c); + } + } + + if self.peek() == Some('.') && self.peek_ahead(1).is_some_and(|c| c.is_ascii_digit()) { + is_float = true; + num_str.push(self.advance().unwrap()); // '.' + while self.peek().is_some_and(|c| c.is_ascii_digit() || c == '_') { + let c = self.advance().unwrap(); + if c != '_' { + num_str.push(c); + } + } + } + + // Exponent + if self.peek() == Some('e') || self.peek() == Some('E') { + is_float = true; + num_str.push(self.advance().unwrap()); + if self.peek() == Some('+') || self.peek() == Some('-') { + num_str.push(self.advance().unwrap()); + } + while self.peek().is_some_and(|c| c.is_ascii_digit()) { + num_str.push(self.advance().unwrap()); + } + } + + // Suffix + if self.peek() == Some('f') || self.peek() == Some('F') { + self.advance(); + is_float = true; + is_explicit_float = true; + } else if self.peek() == Some('d') || self.peek() == Some('D') { + self.advance(); + is_float = true; + is_explicit_double = true; + } else if self.peek() == Some('L') || self.peek() == Some('l') { + self.advance(); + is_long = true; + } + + if is_long { + let val: i64 = num_str + .parse() + .map_err(|_| self.error("invalid long literal"))?; + Ok(Token::LongLiteral(val)) + } else if is_float { + let val: f64 = num_str + .parse() + .map_err(|_| self.error("invalid float literal"))?; + if is_explicit_float { + Ok(Token::FloatLiteral(val)) + } else if is_explicit_double { + Ok(Token::DoubleLiteral(val)) + } else { + // Unadorned float literal defaults to double in Java + Ok(Token::DoubleLiteral(val)) + } + } else { + let val: i64 = num_str + .parse() + .map_err(|_| self.error("invalid integer literal"))?; + Ok(Token::IntLiteral(val)) + } + } 
+ + fn read_ident_or_keyword(&mut self) -> Token { + let mut ident = String::new(); + while self + .peek() + .is_some_and(|c| c.is_ascii_alphanumeric() || c == '_') + { + ident.push(self.advance().unwrap()); + } + match ident.as_str() { + "if" => Token::If, + "else" => Token::Else, + "while" => Token::While, + "for" => Token::For, + "return" => Token::Return, + "new" => Token::New, + "this" => Token::This, + "throw" => Token::Throw, + "break" => Token::Break, + "continue" => Token::Continue, + "instanceof" => Token::Instanceof, + "switch" => Token::Switch, + "case" => Token::Case, + "default" => Token::Default, + "try" => Token::Try, + "catch" => Token::Catch, + "finally" => Token::Finally, + "null" => Token::Null, + "true" => Token::True, + "false" => Token::False, + "synchronized" => Token::Synchronized, + "var" => Token::Var, + "int" => Token::KwInt, + "long" => Token::KwLong, + "float" => Token::KwFloat, + "double" => Token::KwDouble, + "boolean" => Token::KwBoolean, + "byte" => Token::KwByte, + "char" => Token::KwChar, + "short" => Token::KwShort, + "void" => Token::KwVoid, + _ => Token::Ident(ident), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn lex(s: &str) -> Vec<Token> { + Lexer::new(s) + .tokenize() + .unwrap() + .into_iter() + .map(|st| st.token) + .collect() + } + + #[test] + fn test_simple_tokens() { + let tokens = lex("int x = 42;"); + assert_eq!( + tokens, + vec![ + Token::KwInt, + Token::Ident("x".into()), + Token::Eq, + Token::IntLiteral(42), + Token::Semicolon, + Token::Eof, + ] + ); + } + + #[test] + fn test_string_literal() { + let tokens = lex("\"hello\\nworld\""); + assert_eq!( + tokens, + vec![Token::StringLiteral("hello\nworld".into()), Token::Eof] + ); + } + + #[test] + fn test_operators() { + let tokens = lex("a++ && b-- || !c"); + assert_eq!( + tokens, + vec![ + Token::Ident("a".into()), + Token::PlusPlus, + Token::AmpAmp, + Token::Ident("b".into()), + Token::MinusMinus, + Token::PipePipe, + Token::Bang, + 
Token::Ident("c".into()), + Token::Eof, + ] + ); + } + + #[test] + fn test_float_literals() { + let tokens = lex("3.14f 2.0 1L"); + assert_eq!( + tokens, + vec![ + Token::FloatLiteral(3.14), + Token::DoubleLiteral(2.0), + Token::LongLiteral(1), + Token::Eof, + ] + ); + } + + #[test] + fn test_comments() { + let tokens = lex("a // comment\n+ b /* block */ + c"); + assert_eq!( + tokens, + vec![ + Token::Ident("a".into()), + Token::Plus, + Token::Ident("b".into()), + Token::Plus, + Token::Ident("c".into()), + Token::Eof, + ] + ); + } + + #[test] + fn test_shift_operators() { + let tokens = lex("a << b >> c >>> d"); + assert_eq!( + tokens, + vec![ + Token::Ident("a".into()), + Token::LtLt, + Token::Ident("b".into()), + Token::GtGt, + Token::Ident("c".into()), + Token::GtGtGt, + Token::Ident("d".into()), + Token::Eof, + ] + ); + } +} diff --git a/src/compile/mod.rs b/src/compile/mod.rs new file mode 100644 index 0000000..2bff185 --- /dev/null +++ b/src/compile/mod.rs @@ -0,0 +1,327 @@ +pub mod ast; +pub mod codegen; +pub mod lexer; +pub mod parser; +pub mod patch; +pub mod stack_calc; +pub mod stackmap; + +use std::fmt; + +use crate::ClassFile; +use crate::attribute_info::{ExceptionEntry, StackMapTableAttribute}; +use crate::code_attribute::Instruction; + +use self::ast::CStmt; +use self::codegen::CodeGenerator; +use self::lexer::Lexer; +use self::parser::Parser; + +#[derive(Clone, Debug)] +pub enum CompileError { + ParseError { + line: usize, + column: usize, + message: String, + }, + TypeError { + message: String, + }, + CodegenError { + message: String, + }, + MethodNotFound { + name: String, + }, +} + +impl fmt::Display for CompileError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CompileError::ParseError { + line, + column, + message, + } => write!(f, "parse error at {}:{}: {}", line, column, message), + CompileError::TypeError { message } => write!(f, "type error: {}", message), + CompileError::CodegenError { message } => 
write!(f, "codegen error: {}", message), + CompileError::MethodNotFound { name } => write!(f, "method not found: {}", name), + } + } +} + +impl std::error::Error for CompileError {} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub enum InsertMode { + /// Replace the entire method body (default). + #[default] + Replace, + /// Insert compiled code at the beginning, preserving the original body. + Prepend, + /// Insert compiled code at the end, after the original body. + /// + /// The trailing return instruction(s) of the original method are stripped + /// so the original code falls through to the appended code. The appended + /// code is responsible for returning. + Append, +} + +#[derive(Clone)] +pub struct CompileOptions { + pub strip_stack_map_table: bool, + pub generate_stack_map_table: bool, + pub insert_mode: InsertMode, +} + +impl Default for CompileOptions { + fn default() -> Self { + CompileOptions { + strip_stack_map_table: false, + generate_stack_map_table: true, + insert_mode: InsertMode::Replace, + } + } +} + +pub struct GeneratedCode { + pub instructions: Vec<Instruction>, + pub max_stack: u16, + pub max_locals: u16, + pub exception_table: Vec<ExceptionEntry>, + pub stack_map_table: Option<StackMapTableAttribute>, +} + +/// Parse a Java method body into AST statements. +pub fn parse_method_body(source: &str) -> Result<Vec<CStmt>, CompileError> { + let lexer = Lexer::new(source); + let tokens = lexer.tokenize()?; + let mut parser = Parser::new(tokens); + parser.parse_method_body() +} + +/// Generate bytecode from AST statements. +pub fn generate_bytecode( + stmts: &[CStmt], + class_file: &mut ClassFile, + is_static: bool, + method_descriptor: &str, +) -> Result<GeneratedCode, CompileError> { + generate_bytecode_with_options(stmts, class_file, is_static, method_descriptor, false) +} + +/// Generate bytecode from AST statements with options. 
+pub fn generate_bytecode_with_options( + stmts: &[CStmt], + class_file: &mut ClassFile, + is_static: bool, + method_descriptor: &str, + generate_stack_map_table: bool, +) -> Result<GeneratedCode, CompileError> { + let mut codegen = CodeGenerator::new_with_options( + class_file, + is_static, + method_descriptor, + generate_stack_map_table, + &[], + )?; + codegen.generate_body(stmts)?; + codegen.finish() +} + +/// Compile Java source and replace a method's body in the class file. +/// +/// When `method_descriptor` is `Some`, the method is matched by both name and +/// descriptor, disambiguating overloaded methods. When `None`, the first method +/// with the given name is used (legacy behavior). +pub fn compile_method_body( + source: &str, + class_file: &mut ClassFile, + method_name: &str, + method_descriptor: Option<&str>, + options: &CompileOptions, +) -> Result<(), CompileError> { + patch::compile_method_body_impl(source, class_file, method_name, method_descriptor, options) +} + +/// Compile Java source and append it after an existing method body. +/// +/// The trailing return instructions of the original body are stripped +/// so the original code falls through to the appended code. +pub fn append_method_body( + source: &str, + class_file: &mut ClassFile, + method_name: &str, + method_descriptor: Option<&str>, + options: &CompileOptions, +) -> Result<(), CompileError> { + let mut opts = options.clone(); + opts.insert_mode = InsertMode::Append; + patch::compile_method_body_impl(source, class_file, method_name, method_descriptor, &opts) +} + +/// Compile Java source and prepend it to an existing method body. +/// +/// The compiled code is inserted before the original instructions. +/// Trailing return instructions are stripped so the prepended code +/// falls through to the original body. 
+pub fn prepend_method_body( + source: &str, + class_file: &mut ClassFile, + method_name: &str, + method_descriptor: Option<&str>, + options: &CompileOptions, +) -> Result<(), CompileError> { + let mut opts = options.clone(); + opts.insert_mode = InsertMode::Prepend; + patch::compile_method_body_impl(source, class_file, method_name, method_descriptor, &opts) +} + +/// Compile and patch a single method body in a class file. +/// +/// Generates a valid StackMapTable by default so the patched class passes +/// full JVM bytecode verification. +/// +/// # Forms +/// +/// ```ignore +/// // With StackMapTable generation (passes full verification): +/// patch_method!(class_file, "main", r#"{ System.out.println("hello"); }"#)?; +/// +/// // Without StackMapTable (requires -noverify or -XX:-BytecodeVerification*): +/// patch_method!(class_file, "main", r#"{ System.out.println("hello"); }"#, no_verify)?; +/// ``` +#[macro_export] +macro_rules! patch_method { + ($class_file:expr, $method:expr, $source:expr) => { + $crate::compile::compile_method_body( + $source, + &mut $class_file, + $method, + None, + &$crate::compile::CompileOptions { + generate_stack_map_table: true, + ..$crate::compile::CompileOptions::default() + }, + ) + }; + ($class_file:expr, $method:expr, $source:expr, no_verify) => { + $crate::compile::compile_method_body( + $source, + &mut $class_file, + $method, + None, + // `CompileOptions::default()` enables StackMapTable generation, + // so the no_verify form must turn it off explicitly. + &$crate::compile::CompileOptions { + generate_stack_map_table: false, + ..$crate::compile::CompileOptions::default() + }, + ) + }; +} + +/// Compile and patch multiple method bodies in a class file. +/// +/// Each method is compiled and patched in order. If any method fails, +/// the error is returned immediately and subsequent methods are not patched. +/// +/// Generates a valid StackMapTable by default. 
+/// +/// ```ignore +/// patch_methods!(class_file, { +/// "main" => r#"{ System.out.println("hello"); }"#, +/// "helper" => r#"{ return 42; }"#, +/// })?; +/// +/// // Without StackMapTable: +/// patch_methods!(class_file, no_verify, { +/// "main" => r#"{ System.out.println("hello"); }"#, +/// })?; +/// ``` +#[macro_export] +macro_rules! patch_methods { + ($class_file:expr, { $($method:expr => $source:expr),+ $(,)? }) => {{ + (|| -> Result<(), $crate::compile::CompileError> { + $( + $crate::patch_method!($class_file, $method, $source)?; + )+ + Ok(()) + })() + }}; + ($class_file:expr, no_verify, { $($method:expr => $source:expr),+ $(,)? }) => {{ + (|| -> Result<(), $crate::compile::CompileError> { + $( + $crate::patch_method!($class_file, $method, $source, no_verify)?; + )+ + Ok(()) + })() + }}; +} + +/// Prepend compiled Java source to the beginning of a method body. +/// +/// The original method code is preserved; the new code runs first and +/// falls through to the original instructions. +/// +/// ```ignore +/// prepend_method!(class_file, "main", r#"{ System.out.println("entering main"); }"#)?; +/// ``` +#[macro_export] +macro_rules! prepend_method { + ($class_file:expr, $method:expr, $source:expr) => { + $crate::compile::prepend_method_body( + $source, + &mut $class_file, + $method, + None, + &$crate::compile::CompileOptions { + generate_stack_map_table: true, + insert_mode: $crate::compile::InsertMode::Prepend, + ..$crate::compile::CompileOptions::default() + }, + ) + }; + ($class_file:expr, $method:expr, $source:expr, no_verify) => { + $crate::compile::prepend_method_body( + $source, + &mut $class_file, + $method, + None, + // The default options enable StackMapTable generation; disable it for no_verify. + &$crate::compile::CompileOptions { + generate_stack_map_table: false, + insert_mode: $crate::compile::InsertMode::Prepend, + ..$crate::compile::CompileOptions::default() + }, + ) + }; +} + +/// Append compiled Java source after the end of a method body. +/// +/// The original method's trailing return is stripped so it falls through +/// to the appended code. 
The appended code is responsible for returning. +/// +/// ```ignore +/// append_method!(class_file, "main", r#"{ System.out.println("exiting main"); }"#)?; +/// ``` +#[macro_export] +macro_rules! append_method { + ($class_file:expr, $method:expr, $source:expr) => { + $crate::compile::append_method_body( + $source, + &mut $class_file, + $method, + None, + &$crate::compile::CompileOptions { + generate_stack_map_table: true, + insert_mode: $crate::compile::InsertMode::Append, + ..$crate::compile::CompileOptions::default() + }, + ) + }; + ($class_file:expr, $method:expr, $source:expr, no_verify) => { + $crate::compile::append_method_body( + $source, + &mut $class_file, + $method, + None, + // The default options enable StackMapTable generation; disable it for no_verify. + &$crate::compile::CompileOptions { + generate_stack_map_table: false, + insert_mode: $crate::compile::InsertMode::Append, + ..$crate::compile::CompileOptions::default() + }, + ) + }; +} diff --git a/src/compile/parser.rs b/src/compile/parser.rs new file mode 100644 index 0000000..f85ffdb --- /dev/null +++ b/src/compile/parser.rs @@ -0,0 +1,1553 @@ +use super::CompileError; +use super::ast::*; +use super::lexer::{SpannedToken, Token}; + +pub struct Parser { + tokens: Vec<SpannedToken>, + pos: usize, +} + +impl Parser { + pub fn new(tokens: Vec<SpannedToken>) -> Self { + Parser { tokens, pos: 0 } + } + + fn peek(&self) -> &Token { + &self.tokens[self.pos].token + } + + fn at(&self, token: &Token) -> bool { + self.peek() == token + } + + fn advance(&mut self) -> &SpannedToken { + let t = &self.tokens[self.pos]; + if self.pos + 1 < self.tokens.len() { + self.pos += 1; + } + t + } + + fn expect(&mut self, expected: &Token) -> Result<(), CompileError> { + if self.peek() == expected { + self.advance(); + Ok(()) + } else { + Err(self.error(format!("expected {:?}, got {:?}", expected, self.peek()))) + } + } + + fn error(&self, message: impl Into<String>) -> CompileError { + let span = &self.tokens[self.pos]; + CompileError::ParseError { + line: span.line, + column: span.column, + message: message.into(), + } + } + + fn expect_ident(&mut self) -> Result<String, CompileError> { + if let 
Token::Ident(name) = self.peek().clone() { + self.advance(); + Ok(name) + } else { + Err(self.error(format!("expected identifier, got {:?}", self.peek()))) + } + } + + /// Parse a method body: "{" statement* "}" + pub fn parse_method_body(&mut self) -> Result<Vec<CStmt>, CompileError> { + self.expect(&Token::LBrace)?; + let mut stmts = Vec::new(); + while !self.at(&Token::RBrace) && !self.at(&Token::Eof) { + stmts.push(self.parse_statement()?); + } + self.expect(&Token::RBrace)?; + Ok(stmts) + } + + fn parse_statement(&mut self) -> Result<CStmt, CompileError> { + match self.peek() { + Token::LBrace => { + self.advance(); + let mut stmts = Vec::new(); + while !self.at(&Token::RBrace) && !self.at(&Token::Eof) { + stmts.push(self.parse_statement()?); + } + self.expect(&Token::RBrace)?; + Ok(CStmt::Block(stmts)) + } + Token::If => self.parse_if(), + Token::While => self.parse_while(), + Token::For => self.parse_for(), + Token::Switch => self.parse_switch(), + Token::Synchronized => self.parse_synchronized(), + Token::Try => self.parse_try_catch(), + Token::Return => self.parse_return(), + Token::Throw => self.parse_throw(), + Token::Break => { + self.advance(); + self.expect(&Token::Semicolon)?; + Ok(CStmt::Break) + } + Token::Continue => { + self.advance(); + self.expect(&Token::Semicolon)?; + Ok(CStmt::Continue) + } + Token::Var => { + self.advance(); + let name = self.expect_ident()?; + self.expect(&Token::Eq)?; + let init = self.parse_expression()?; + self.expect(&Token::Semicolon)?; + Ok(CStmt::LocalDecl { + ty: TypeName::Class("__var__".into()), + name, + init: Some(init), + }) + } + // Attempt type name for local declaration + Token::KwInt + | Token::KwLong + | Token::KwFloat + | Token::KwDouble + | Token::KwBoolean + | Token::KwByte + | Token::KwChar + | Token::KwShort + | Token::KwVoid => self.parse_local_decl(), + // Could be a local decl with class type or an expression statement + Token::Ident(_) => { + // Lookahead to distinguish type declaration from expression + if 
self.is_local_decl_start() { + self.parse_local_decl() + } else { + self.parse_expr_statement() + } + } + _ => self.parse_expr_statement(), + } + } + + /// Lookahead to determine if current position starts a local declaration. + /// Pattern: Ident ("." Ident)* ("<" ... ">")? ("[" "]")* Ident + /// vs expression: Ident followed by operator/dot-method/etc. + fn is_local_decl_start(&self) -> bool { + let mut i = self.pos; + // Must start with Ident + if !matches!(&self.tokens[i].token, Token::Ident(_)) { + return false; + } + i += 1; + // Skip dotted name: "." Ident + while i < self.tokens.len() { + if self.tokens[i].token == Token::Dot { + i += 1; + if i < self.tokens.len() && matches!(&self.tokens[i].token, Token::Ident(_)) { + i += 1; + } else { + return false; + } + } else { + break; + } + } + // Skip generic type parameters: "<" ... ">" + if i < self.tokens.len() && self.tokens[i].token == Token::Lt { + i += 1; + let mut depth: i32 = 1; + while i < self.tokens.len() && depth > 0 { + match &self.tokens[i].token { + Token::Lt => { + depth += 1; + i += 1; + } + Token::Gt => { + depth -= 1; + i += 1; + } + Token::GtGt => { + depth -= 2; + i += 1; + } + Token::Eof => break, + _ => { + i += 1; + } + } + } + } + // Skip array brackets: "[" "]" + while i + 1 < self.tokens.len() + && self.tokens[i].token == Token::LBracket + && self.tokens[i + 1].token == Token::RBracket + { + i += 2; + } + // Must be followed by an identifier (the variable name) + i < self.tokens.len() && matches!(&self.tokens[i].token, Token::Ident(_)) + } + + fn parse_if(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::If)?; + self.expect(&Token::LParen)?; + let condition = self.parse_expression()?; + self.expect(&Token::RParen)?; + let then_body = self.parse_block_or_single()?; + let else_body = if self.at(&Token::Else) { + self.advance(); + Some(self.parse_block_or_single()?) 
+ } else { + None + }; + Ok(CStmt::If { + condition, + then_body, + else_body, + }) + } + + fn parse_while(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::While)?; + self.expect(&Token::LParen)?; + let condition = self.parse_expression()?; + self.expect(&Token::RParen)?; + let body = self.parse_block_or_single()?; + Ok(CStmt::While { condition, body }) + } + + fn parse_for(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::For)?; + self.expect(&Token::LParen)?; + + // Try for-each: for (Type name : expr) + if self.is_type_start() { + let saved_pos = self.pos; + if let Ok(ty) = self.parse_type_name() + && let Ok(name) = self.expect_ident() + && self.at(&Token::Colon) + { + self.advance(); // consume ':' + let iterable = self.parse_expression()?; + self.expect(&Token::RParen)?; + let body = self.parse_block_or_single()?; + return Ok(CStmt::ForEach { + element_type: ty, + var_name: name, + iterable, + body, + }); + } + // Not a for-each, restore position and parse as traditional for + self.pos = saved_pos; + } + + // Init + let init = if self.at(&Token::Semicolon) { + self.advance(); + None + } else { + let stmt = if self.is_type_start() { + self.parse_local_decl_no_semi()? + } else { + let expr = self.parse_expression()?; + CStmt::ExprStmt(expr) + }; + self.expect(&Token::Semicolon)?; + Some(Box::new(stmt)) + }; + + // Condition + let condition = if self.at(&Token::Semicolon) { + None + } else { + Some(self.parse_expression()?) 
+ }; + self.expect(&Token::Semicolon)?; + + // Update + let update = if self.at(&Token::RParen) { + None + } else { + let expr = self.parse_expression()?; + Some(Box::new(CStmt::ExprStmt(expr))) + }; + self.expect(&Token::RParen)?; + + let body = self.parse_block_or_single()?; + Ok(CStmt::For { + init, + condition, + update, + body, + }) + } + + fn parse_return(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::Return)?; + if self.at(&Token::Semicolon) { + self.advance(); + Ok(CStmt::Return(None)) + } else { + let expr = self.parse_expression()?; + self.expect(&Token::Semicolon)?; + Ok(CStmt::Return(Some(expr))) + } + } + + fn parse_throw(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::Throw)?; + let expr = self.parse_expression()?; + self.expect(&Token::Semicolon)?; + Ok(CStmt::Throw(expr)) + } + + fn parse_switch(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::Switch)?; + self.expect(&Token::LParen)?; + let expr = self.parse_expression()?; + self.expect(&Token::RParen)?; + self.expect(&Token::LBrace)?; + + let mut cases: Vec<SwitchCase> = Vec::new(); + let mut default_body: Option<Vec<CStmt>> = None; + // Accumulate consecutive case labels before a body + let mut pending_values: Vec<i64> = Vec::new(); + + while !self.at(&Token::RBrace) && !self.at(&Token::Eof) { + if self.at(&Token::Case) { + self.advance(); + let value = self.parse_case_value()?; + self.expect(&Token::Colon)?; + pending_values.push(value); + } else if self.at(&Token::Default) { + self.advance(); + self.expect(&Token::Colon)?; + // Collect the body for default + let mut body = Vec::new(); + while !self.at(&Token::RBrace) + && !self.at(&Token::Case) + && !self.at(&Token::Default) + && !self.at(&Token::Eof) + { + body.push(self.parse_statement()?); + } + default_body = Some(body); + } else { + // This is a statement that belongs to the most recent case/default label(s) + let mut body = Vec::new(); + while !self.at(&Token::RBrace) + && !self.at(&Token::Case) + && !self.at(&Token::Default) + && !self.at(&Token::Eof) + { + 
body.push(self.parse_statement()?); + } + if !pending_values.is_empty() { + cases.push(SwitchCase { + values: std::mem::take(&mut pending_values), + body, + }); + } else { + return Err(self.error("unexpected statement outside case/default in switch")); + } + } + } + // If there are pending values without a body (fall-through to end), add empty case + if !pending_values.is_empty() { + cases.push(SwitchCase { + values: pending_values, + body: Vec::new(), + }); + } + self.expect(&Token::RBrace)?; + Ok(CStmt::Switch { + expr, + cases, + default_body, + }) + } + + fn parse_case_value(&mut self) -> Result<i64, CompileError> { + let negative = if self.at(&Token::Minus) { + self.advance(); + true + } else { + false + }; + match self.peek().clone() { + Token::IntLiteral(v) => { + self.advance(); + Ok(if negative { -v } else { v }) + } + Token::LongLiteral(v) => { + self.advance(); + Ok(if negative { -v } else { v }) + } + _ => Err(self.error(format!( + "expected integer case value, got {:?}", + self.peek() + ))), + } + } + + fn parse_try_catch(&mut self) -> Result<CStmt, CompileError> { + self.expect(&Token::Try)?; + let try_body = self.parse_block_or_single()?; + + let mut catches = Vec::new(); + let mut finally_body = None; + + while self.at(&Token::Catch) { + self.advance(); + self.expect(&Token::LParen)?; + let mut exception_types = vec![self.parse_type_name()?]; + while self.at(&Token::Pipe) { + self.advance(); + exception_types.push(self.parse_type_name()?); + } + let var_name = self.expect_ident()?; + self.expect(&Token::RParen)?; + let body = self.parse_block_or_single()?; + catches.push(CatchClause { + exception_types, + var_name, + body, + }); + } + + if self.at(&Token::Finally) { + self.advance(); + finally_body = Some(self.parse_block_or_single()?); + } + + if catches.is_empty() && finally_body.is_none() { + return Err(self.error("try requires at least one catch or finally block")); + } + + Ok(CStmt::TryCatch { + try_body, + catches, + finally_body, + }) + } + + fn parse_synchronized(&mut self) -> 
Result { + self.expect(&Token::Synchronized)?; + self.expect(&Token::LParen)?; + let lock_expr = self.parse_expression()?; + self.expect(&Token::RParen)?; + let body = self.parse_block_or_single()?; + Ok(CStmt::Synchronized { lock_expr, body }) + } + + fn parse_local_decl(&mut self) -> Result { + let stmt = self.parse_local_decl_no_semi()?; + self.expect(&Token::Semicolon)?; + Ok(stmt) + } + + fn parse_local_decl_no_semi(&mut self) -> Result { + let ty = self.parse_type_name()?; + let name = self.expect_ident()?; + let init = if self.at(&Token::Eq) { + self.advance(); + Some(self.parse_expression()?) + } else { + None + }; + Ok(CStmt::LocalDecl { ty, name, init }) + } + + fn parse_expr_statement(&mut self) -> Result { + let expr = self.parse_expression()?; + self.expect(&Token::Semicolon)?; + Ok(CStmt::ExprStmt(expr)) + } + + fn parse_block_or_single(&mut self) -> Result, CompileError> { + if self.at(&Token::LBrace) { + self.advance(); + let mut stmts = Vec::new(); + while !self.at(&Token::RBrace) && !self.at(&Token::Eof) { + stmts.push(self.parse_statement()?); + } + self.expect(&Token::RBrace)?; + Ok(stmts) + } else { + Ok(vec![self.parse_statement()?]) + } + } + + fn is_type_start(&self) -> bool { + matches!( + self.peek(), + Token::KwInt + | Token::KwLong + | Token::KwFloat + | Token::KwDouble + | Token::KwBoolean + | Token::KwByte + | Token::KwChar + | Token::KwShort + | Token::KwVoid + ) || (matches!(self.peek(), Token::Ident(_)) && self.is_local_decl_start()) + } + + fn parse_type_name(&mut self) -> Result { + let base = match self.peek() { + Token::KwInt => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Int) + } + Token::KwLong => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Long) + } + Token::KwFloat => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Float) + } + Token::KwDouble => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Double) + } + Token::KwBoolean => { + self.advance(); + 
TypeName::Primitive(PrimitiveKind::Boolean) + } + Token::KwByte => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Byte) + } + Token::KwChar => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Char) + } + Token::KwShort => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Short) + } + Token::KwVoid => { + self.advance(); + TypeName::Primitive(PrimitiveKind::Void) + } + Token::Ident(_) => { + let mut name = self.expect_ident()?; + while self.at(&Token::Dot) { + // Peek ahead: if next is ident followed by another dot or bracket or ident (type context), continue + if let Token::Ident(_) = &self.tokens[self.pos + 1].token { + // Check if this is part of a dotted class name + // "Ident.Ident" in type position + self.advance(); // consume dot + let next = self.expect_ident()?; + name = format!("{}.{}", name, next); + } else { + break; + } + } + // Skip generic type parameters: List, Map, etc. (erased at compile time) + if self.at(&Token::Lt) { + self.skip_type_parameters()?; + } + TypeName::Class(name) + } + _ => return Err(self.error(format!("expected type name, got {:?}", self.peek()))), + }; + + // Handle array brackets + let mut ty = base; + while self.at(&Token::LBracket) + && self + .tokens + .get(self.pos + 1) + .is_some_and(|t| t.token == Token::RBracket) + { + self.advance(); // [ + self.advance(); // ] + ty = TypeName::Array(Box::new(ty)); + } + + Ok(ty) + } + + // --- Expression parsing with Java operator precedence --- + + fn parse_expression(&mut self) -> Result { + self.parse_assignment() + } + + fn parse_assignment(&mut self) -> Result { + let expr = self.parse_ternary()?; + + match self.peek() { + Token::Eq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::Assign { + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::PlusEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Add, + target: Box::new(expr), + value: Box::new(value), + }) + 
} + Token::MinusEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Sub, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::StarEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Mul, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::SlashEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Div, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::PercentEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Rem, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::AmpEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::BitAnd, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::PipeEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::BitOr, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::CaretEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::BitXor, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::LtLtEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Shl, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::GtGtEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Shr, + target: Box::new(expr), + value: Box::new(value), + }) + } + Token::GtGtGtEq => { + self.advance(); + let value = self.parse_assignment()?; + Ok(CExpr::CompoundAssign { + op: BinOp::Ushr, + target: Box::new(expr), + value: Box::new(value), + }) + } + _ => Ok(expr), + } + } + + fn parse_ternary(&mut self) -> Result { + let expr = self.parse_logical_or()?; + 
if self.at(&Token::Question) { + self.advance(); + let then_expr = self.parse_expression()?; + self.expect(&Token::Colon)?; + let else_expr = self.parse_ternary()?; + Ok(CExpr::Ternary { + condition: Box::new(expr), + then_expr: Box::new(then_expr), + else_expr: Box::new(else_expr), + }) + } else { + Ok(expr) + } + } + + fn parse_logical_or(&mut self) -> Result { + let mut left = self.parse_logical_and()?; + while self.at(&Token::PipePipe) { + self.advance(); + let right = self.parse_logical_and()?; + left = CExpr::LogicalOr(Box::new(left), Box::new(right)); + } + Ok(left) + } + + fn parse_logical_and(&mut self) -> Result { + let mut left = self.parse_bitwise_or()?; + while self.at(&Token::AmpAmp) { + self.advance(); + let right = self.parse_bitwise_or()?; + left = CExpr::LogicalAnd(Box::new(left), Box::new(right)); + } + Ok(left) + } + + fn parse_bitwise_or(&mut self) -> Result { + let mut left = self.parse_bitwise_xor()?; + while self.at(&Token::Pipe) { + self.advance(); + let right = self.parse_bitwise_xor()?; + left = CExpr::BinaryOp { + op: BinOp::BitOr, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_bitwise_xor(&mut self) -> Result { + let mut left = self.parse_bitwise_and()?; + while self.at(&Token::Caret) { + self.advance(); + let right = self.parse_bitwise_and()?; + left = CExpr::BinaryOp { + op: BinOp::BitXor, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_bitwise_and(&mut self) -> Result { + let mut left = self.parse_equality()?; + while self.at(&Token::Amp) { + self.advance(); + let right = self.parse_equality()?; + left = CExpr::BinaryOp { + op: BinOp::BitAnd, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_equality(&mut self) -> Result { + let mut left = self.parse_relational()?; + loop { + let op = match self.peek() { + Token::EqEq => CompareOp::Eq, + Token::BangEq => CompareOp::Ne, + _ => break, + }; + self.advance(); + let right = 
self.parse_relational()?; + left = CExpr::Comparison { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_relational(&mut self) -> Result { + let mut left = self.parse_shift()?; + loop { + match self.peek() { + Token::Lt => { + self.advance(); + let right = self.parse_shift()?; + left = CExpr::Comparison { + op: CompareOp::Lt, + left: Box::new(left), + right: Box::new(right), + }; + } + Token::LtEq => { + self.advance(); + let right = self.parse_shift()?; + left = CExpr::Comparison { + op: CompareOp::Le, + left: Box::new(left), + right: Box::new(right), + }; + } + Token::Gt => { + self.advance(); + let right = self.parse_shift()?; + left = CExpr::Comparison { + op: CompareOp::Gt, + left: Box::new(left), + right: Box::new(right), + }; + } + Token::GtEq => { + self.advance(); + let right = self.parse_shift()?; + left = CExpr::Comparison { + op: CompareOp::Ge, + left: Box::new(left), + right: Box::new(right), + }; + } + Token::Instanceof => { + self.advance(); + let ty = self.parse_type_name()?; + left = CExpr::Instanceof { + operand: Box::new(left), + ty, + }; + } + _ => break, + } + } + Ok(left) + } + + fn parse_shift(&mut self) -> Result { + let mut left = self.parse_additive()?; + loop { + let op = match self.peek() { + Token::LtLt => BinOp::Shl, + Token::GtGt => BinOp::Shr, + Token::GtGtGt => BinOp::Ushr, + _ => break, + }; + self.advance(); + let right = self.parse_additive()?; + left = CExpr::BinaryOp { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_additive(&mut self) -> Result { + let mut left = self.parse_multiplicative()?; + loop { + let op = match self.peek() { + Token::Plus => BinOp::Add, + Token::Minus => BinOp::Sub, + _ => break, + }; + self.advance(); + let right = self.parse_multiplicative()?; + left = CExpr::BinaryOp { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_multiplicative(&mut self) -> Result { + let mut left = 
self.parse_unary()?; + loop { + let op = match self.peek() { + Token::Star => BinOp::Mul, + Token::Slash => BinOp::Div, + Token::Percent => BinOp::Rem, + _ => break, + }; + self.advance(); + let right = self.parse_unary()?; + left = CExpr::BinaryOp { + op, + left: Box::new(left), + right: Box::new(right), + }; + } + Ok(left) + } + + fn parse_unary(&mut self) -> Result { + match self.peek() { + Token::Minus => { + self.advance(); + // Handle negative literals directly + match self.peek() { + Token::IntLiteral(v) => { + let v = *v; + self.advance(); + Ok(CExpr::IntLiteral(-v)) + } + Token::LongLiteral(v) => { + let v = *v; + self.advance(); + Ok(CExpr::LongLiteral(-v)) + } + Token::FloatLiteral(v) => { + let v = *v; + self.advance(); + Ok(CExpr::FloatLiteral(-v)) + } + Token::DoubleLiteral(v) => { + let v = *v; + self.advance(); + Ok(CExpr::DoubleLiteral(-v)) + } + _ => { + let operand = self.parse_unary()?; + Ok(CExpr::UnaryOp { + op: UnaryOp::Neg, + operand: Box::new(operand), + }) + } + } + } + Token::Bang => { + self.advance(); + let operand = self.parse_unary()?; + Ok(CExpr::LogicalNot(Box::new(operand))) + } + Token::Tilde => { + self.advance(); + let operand = self.parse_unary()?; + Ok(CExpr::UnaryOp { + op: UnaryOp::BitNot, + operand: Box::new(operand), + }) + } + Token::PlusPlus => { + self.advance(); + let operand = self.parse_unary()?; + Ok(CExpr::PreIncrement(Box::new(operand))) + } + Token::MinusMinus => { + self.advance(); + let operand = self.parse_unary()?; + Ok(CExpr::PreDecrement(Box::new(operand))) + } + Token::LParen => { + // Could be cast or parenthesized expression + if self.is_cast() { + self.advance(); // ( + let ty = self.parse_type_name()?; + self.expect(&Token::RParen)?; + let operand = self.parse_unary()?; + Ok(CExpr::Cast { + ty, + operand: Box::new(operand), + }) + } else { + self.parse_postfix() + } + } + _ => self.parse_postfix(), + } + } + + /// Distinguish cast from parenthesized expression. 
+ /// Cast: "(" type_name ")" unary_expr + fn is_cast(&self) -> bool { + // Check if ( is followed by a type name and then ) + let mut i = self.pos + 1; // skip ( + if i >= self.tokens.len() { + return false; + } + // Check for primitive type + if matches!( + &self.tokens[i].token, + Token::KwInt + | Token::KwLong + | Token::KwFloat + | Token::KwDouble + | Token::KwBoolean + | Token::KwByte + | Token::KwChar + | Token::KwShort + ) { + i += 1; + // skip array brackets + while i + 1 < self.tokens.len() + && self.tokens[i].token == Token::LBracket + && self.tokens[i + 1].token == Token::RBracket + { + i += 2; + } + return i < self.tokens.len() && self.tokens[i].token == Token::RParen; + } + // Check for class type cast: (ClassName) expr + // Only if the ident is likely a type (starts with uppercase) and is followed by ) + if let Token::Ident(name) = &self.tokens[i].token + && name.chars().next().is_some_and(|c| c.is_ascii_uppercase()) + { + i += 1; + // Skip dotted name + while i + 1 < self.tokens.len() + && self.tokens[i].token == Token::Dot + && matches!(&self.tokens[i + 1].token, Token::Ident(_)) + { + i += 2; + } + // skip array brackets + while i + 1 < self.tokens.len() + && self.tokens[i].token == Token::LBracket + && self.tokens[i + 1].token == Token::RBracket + { + i += 2; + } + return i < self.tokens.len() && self.tokens[i].token == Token::RParen; + } + false + } + + fn parse_postfix(&mut self) -> Result { + let mut expr = self.parse_primary()?; + + loop { + match self.peek() { + Token::Dot => { + self.advance(); + // Handle generic type parameters before method name: obj.method() + if self.at(&Token::Lt) { + self.skip_type_parameters()?; + } + let name = self.expect_ident()?; + // Also handle generics after method name: obj.method() + if self.at(&Token::Lt) { + self.skip_type_parameters()?; + } + if self.at(&Token::LParen) { + // Method call + let args = self.parse_args()?; + expr = CExpr::MethodCall { + object: Some(Box::new(expr)), + name, + args, + }; + } 
else { + // Field access + expr = CExpr::FieldAccess { + object: Box::new(expr), + name, + }; + } + } + Token::LBracket => { + self.advance(); + let index = self.parse_expression()?; + self.expect(&Token::RBracket)?; + expr = CExpr::ArrayAccess { + array: Box::new(expr), + index: Box::new(index), + }; + } + Token::PlusPlus => { + self.advance(); + expr = CExpr::PostIncrement(Box::new(expr)); + } + Token::MinusMinus => { + self.advance(); + expr = CExpr::PostDecrement(Box::new(expr)); + } + Token::ColonColon => { + // Method reference: expr::methodName + self.advance(); + let method_name = self.expect_ident()?; + // Extract class name from the expression + let class_name = match &expr { + CExpr::Ident(name) => name.clone(), + _ => return Err(self.error("method reference requires a class name")), + }; + expr = CExpr::MethodRef { + class_name, + method_name, + }; + } + _ => break, + } + } + + Ok(expr) + } + + fn parse_primary(&mut self) -> Result { + match self.peek().clone() { + Token::IntLiteral(v) => { + let v = v; + self.advance(); + Ok(CExpr::IntLiteral(v)) + } + Token::LongLiteral(v) => { + let v = v; + self.advance(); + Ok(CExpr::LongLiteral(v)) + } + Token::FloatLiteral(v) => { + let v = v; + self.advance(); + Ok(CExpr::FloatLiteral(v)) + } + Token::DoubleLiteral(v) => { + let v = v; + self.advance(); + Ok(CExpr::DoubleLiteral(v)) + } + Token::StringLiteral(s) => { + let s = s; + self.advance(); + Ok(CExpr::StringLiteral(s)) + } + Token::CharLiteral(c) => { + self.advance(); + Ok(CExpr::CharLiteral(c)) + } + Token::True => { + self.advance(); + Ok(CExpr::BoolLiteral(true)) + } + Token::False => { + self.advance(); + Ok(CExpr::BoolLiteral(false)) + } + Token::Null => { + self.advance(); + Ok(CExpr::NullLiteral) + } + Token::This => { + self.advance(); + Ok(CExpr::This) + } + Token::New => { + self.advance(); + let ty = self.parse_type_name()?; + if self.at(&Token::LBracket) { + // new Type[size] or new Type[size1][size2]... 
+ self.advance(); + let first_size = self.parse_expression()?; + self.expect(&Token::RBracket)?; + + // Check for multi-dimensional: new Type[expr][expr]... + if self.at(&Token::LBracket) + && self + .tokens + .get(self.pos + 1) + .is_some_and(|t| t.token != Token::RBracket) + { + let mut dimensions = vec![first_size]; + while self.at(&Token::LBracket) + && self + .tokens + .get(self.pos + 1) + .is_some_and(|t| t.token != Token::RBracket) + { + self.advance(); // [ + dimensions.push(self.parse_expression()?); + self.expect(&Token::RBracket)?; + } + Ok(CExpr::NewMultiArray { + element_type: ty, + dimensions, + }) + } else { + Ok(CExpr::NewArray { + element_type: ty, + size: Box::new(first_size), + }) + } + } else if self.at(&Token::LParen) { + // new ClassName(args) + let class_name = match ty { + TypeName::Class(name) => name, + _ => { + return Err( + self.error("cannot use 'new' with primitive type constructor") + ); + } + }; + let args = self.parse_args()?; + Ok(CExpr::NewObject { class_name, args }) + } else { + Err(self.error("expected '(' or '[' after 'new Type'")) + } + } + Token::Switch => { + self.advance(); + self.parse_switch_expr() + } + Token::LParen => { + // Check for lambda: () -> or (Type name, ...) -> + if self.is_lambda_start() { + return self.parse_lambda(); + } + self.advance(); + let expr = self.parse_expression()?; + self.expect(&Token::RParen)?; + Ok(expr) + } + Token::Ident(name) => { + let name = name; + self.advance(); + if self.at(&Token::LParen) { + // Bare method call (unqualified) + let args = self.parse_args()?; + Ok(CExpr::MethodCall { + object: None, + name, + args, + }) + } else { + Ok(CExpr::Ident(name)) + } + } + _ => Err(self.error(format!("unexpected token: {:?}", self.peek()))), + } + } + + /// Skip generic type parameters `` in postfix position. + /// In postfix (after `.name`), `<` is unambiguous — it's always generics, not comparison. 
+ fn skip_type_parameters(&mut self) -> Result<(), CompileError> { + self.expect(&Token::Lt)?; + let mut depth: i32 = 1; + while depth > 0 && !self.at(&Token::Eof) { + match self.peek() { + Token::Lt => { + depth += 1; + self.advance(); + } + Token::Gt => { + depth -= 1; + self.advance(); + } + Token::GtGt => { + depth -= 2; + self.advance(); + } + _ => { + self.advance(); + } + } + } + Ok(()) + } + + fn parse_switch_expr(&mut self) -> Result { + self.expect(&Token::LParen)?; + let expr = self.parse_expression()?; + self.expect(&Token::RParen)?; + self.expect(&Token::LBrace)?; + + let mut cases = Vec::new(); + let mut default_expr = None; + + while !self.at(&Token::RBrace) && !self.at(&Token::Eof) { + if self.at(&Token::Default) { + self.advance(); + self.expect(&Token::Arrow)?; + let e = self.parse_expression()?; + self.expect(&Token::Semicolon)?; + default_expr = Some(e); + } else { + self.expect(&Token::Case)?; + let mut values = Vec::new(); + values.push(self.parse_case_value()?); + while self.at(&Token::Comma) { + self.advance(); + values.push(self.parse_case_value()?); + } + self.expect(&Token::Arrow)?; + let case_expr = self.parse_expression()?; + self.expect(&Token::Semicolon)?; + cases.push(SwitchExprCase { + values, + expr: case_expr, + }); + } + } + self.expect(&Token::RBrace)?; + + let default = + default_expr.ok_or_else(|| self.error("switch expression requires a default case"))?; + + Ok(CExpr::SwitchExpr { + expr: Box::new(expr), + cases, + default_expr: Box::new(default), + }) + } + + /// Lookahead to determine if `(` starts a lambda expression. 
+ /// Patterns: `() ->`, `(Type name) ->`, `(Type name, Type name) ->` + fn is_lambda_start(&self) -> bool { + let mut i = self.pos + 1; // skip `(` + + // () -> is a zero-arg lambda + if i < self.tokens.len() && self.tokens[i].token == Token::RParen { + return i + 1 < self.tokens.len() && self.tokens[i + 1].token == Token::Arrow; + } + + // Scan for matching `)` then check for `->` + let mut depth = 1; + while i < self.tokens.len() && depth > 0 { + match &self.tokens[i].token { + Token::LParen => depth += 1, + Token::RParen => depth -= 1, + Token::Eof => return false, + _ => {} + } + i += 1; + } + // After `)`, check for `->` + if depth == 0 && i < self.tokens.len() { + return self.tokens[i].token == Token::Arrow; + } + false + } + + fn parse_lambda(&mut self) -> Result { + self.expect(&Token::LParen)?; + let mut params = Vec::new(); + + if !self.at(&Token::RParen) { + loop { + // Try to parse a typed parameter: Type name + // Or just an identifier (inferred type) + if self.is_type_start() { + let saved = self.pos; + if let Ok(ty) = self.parse_type_name() { + if let Token::Ident(_) = self.peek() { + let name = self.expect_ident()?; + params.push(LambdaParam { ty: Some(ty), name }); + } else { + // Not Type name pattern, restore and try ident + self.pos = saved; + let name = self.expect_ident()?; + params.push(LambdaParam { ty: None, name }); + } + } else { + self.pos = saved; + let name = self.expect_ident()?; + params.push(LambdaParam { ty: None, name }); + } + } else { + let name = self.expect_ident()?; + params.push(LambdaParam { ty: None, name }); + } + if self.at(&Token::Comma) { + self.advance(); + } else { + break; + } + } + } + self.expect(&Token::RParen)?; + self.expect(&Token::Arrow)?; + + let body = if self.at(&Token::LBrace) { + let stmts = self.parse_block_or_single()?; + LambdaBody::Block(stmts) + } else { + let expr = self.parse_expression()?; + LambdaBody::Expr(Box::new(expr)) + }; + + Ok(CExpr::Lambda { params, body }) + } + + fn parse_args(&mut 
self) -> Result<Vec<CExpr>, CompileError> { + self.expect(&Token::LParen)?; + let mut args = Vec::new(); + if !self.at(&Token::RParen) { + args.push(self.parse_expression()?); + while self.at(&Token::Comma) { + self.advance(); + args.push(self.parse_expression()?); + } + } + self.expect(&Token::RParen)?; + Ok(args) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::compile::lexer::Lexer; + + fn parse(src: &str) -> Vec<CStmt> { + let tokens = Lexer::new(src).tokenize().unwrap(); + let mut parser = Parser::new(tokens); + parser.parse_method_body().unwrap() + } + + #[test] + fn test_return_literal() { + let stmts = parse("{ return 42; }"); + assert_eq!(stmts, vec![CStmt::Return(Some(CExpr::IntLiteral(42)))]); + } + + #[test] + fn test_local_decl() { + let stmts = parse("{ int x = 10; }"); + assert_eq!( + stmts, + vec![CStmt::LocalDecl { + ty: TypeName::Primitive(PrimitiveKind::Int), + name: "x".into(), + init: Some(CExpr::IntLiteral(10)), + }] + ); + } + + #[test] + fn test_if_else() { + let stmts = parse("{ if (x > 0) { return 1; } else { return 0; } }"); + assert!(matches!( + &stmts[0], + CStmt::If { + else_body: Some(_), + .. + } + )); + } + + #[test] + fn test_while_loop() { + let stmts = parse("{ while (i < 10) { i = i + 1; } }"); + assert!(matches!(&stmts[0], CStmt::While { .. })); + } + + #[test] + fn test_for_loop() { + let stmts = parse("{ for (int i = 0; i < 10; i++) { x = x + i; } }"); + assert!(matches!(&stmts[0], CStmt::For { .. })); + } + + #[test] + fn test_method_call() { + let stmts = parse("{ System.out.println(\"hello\"); }"); + assert!(matches!( + &stmts[0], + CStmt::ExprStmt(CExpr::MethodCall { .. }) + )); + } + + #[test] + fn test_new_object() { + let stmts = parse("{ StringBuilder sb = new StringBuilder(); }"); + assert!(matches!( + &stmts[0], + CStmt::LocalDecl { + init: Some(CExpr::NewObject { .. }), + .. 
+ } + )); + } + + #[test] + fn test_arithmetic_precedence() { + let stmts = parse("{ return a + b * c; }"); + match &stmts[0] { + CStmt::Return(Some(CExpr::BinaryOp { + op: BinOp::Add, + right, + .. + })) => { + assert!(matches!( + right.as_ref(), + CExpr::BinaryOp { op: BinOp::Mul, .. } + )); + } + other => panic!("unexpected: {:?}", other), + } + } + + #[test] + fn test_ternary() { + let stmts = parse("{ return x > 0 ? 1 : -1; }"); + match &stmts[0] { + CStmt::Return(Some(CExpr::Ternary { .. })) => {} + other => panic!("unexpected: {:?}", other), + } + } + + #[test] + fn test_cast() { + let stmts = parse("{ long x = (long) y; }"); + match &stmts[0] { + CStmt::LocalDecl { + init: Some(CExpr::Cast { ty, .. }), + .. + } => { + assert_eq!(*ty, TypeName::Primitive(PrimitiveKind::Long)); + } + other => panic!("unexpected: {:?}", other), + } + } + + #[test] + fn test_array_access() { + let stmts = parse("{ return arr[i]; }"); + match &stmts[0] { + CStmt::Return(Some(CExpr::ArrayAccess { .. })) => {} + other => panic!("unexpected: {:?}", other), + } + } + + #[test] + fn test_class_type_decl() { + let stmts = parse("{ String s = \"hello\"; }"); + assert_eq!( + stmts, + vec![CStmt::LocalDecl { + ty: TypeName::Class("String".into()), + name: "s".into(), + init: Some(CExpr::StringLiteral("hello".into())), + }] + ); + } + + #[test] + fn test_compound_assign() { + let stmts = parse("{ x += 1; }"); + match &stmts[0] { + CStmt::ExprStmt(CExpr::CompoundAssign { op: BinOp::Add, .. 
}) => {} + other => panic!("unexpected: {:?}", other), + } + } +} diff --git a/src/compile/patch.rs b/src/compile/patch.rs new file mode 100644 index 0000000..94961fd --- /dev/null +++ b/src/compile/patch.rs @@ -0,0 +1,691 @@ +use crate::ClassFile; +use crate::attribute_info::{ + AttributeInfo, AttributeInfoVariant, CodeAttribute, StackMapFrame, StackMapFrameInner, + StackMapTableAttribute, +}; +use crate::code_attribute::Instruction; +use crate::constant_info::ConstantInfo; +use crate::decompile::descriptor::parse_method_descriptor; +use crate::method_info::MethodAccessFlags; + +use super::codegen::{CodeGenerator, compute_byte_addresses}; +use super::lexer::Lexer; +use super::parser::Parser; +use super::{CompileError, CompileOptions, InsertMode}; + +/// Extract parameter names from debug info (MethodParameters or LocalVariableTable). +/// +/// Returns a Vec with one entry per declared parameter. Each entry is `Some(name)` if +/// a debug name was found, or `None` otherwise. +fn extract_param_names( + class_file: &ClassFile, + method_idx: usize, + is_static: bool, + method_descriptor: &str, +) -> Vec<Option<String>> { + let (params, _) = match parse_method_descriptor(method_descriptor) { + Some(p) => p, + None => return vec![], + }; + let param_count = params.len(); + if param_count == 0 { + return vec![]; + } + + let method = &class_file.methods[method_idx]; + + // Try MethodParameters attribute first (method-level, available with javac -parameters) + for attr in &method.attributes { + if let Some(AttributeInfoVariant::MethodParameters(mp)) = &attr.info_parsed { + let mut names = Vec::with_capacity(param_count); + for (i, p) in mp.parameters.iter().enumerate() { + if i >= param_count { + break; + } + if p.name_index != 0 + && let Some(name) = class_file.get_utf8(p.name_index) + { + names.push(Some(name.to_string())); + continue; + } + names.push(None); + } + // Pad if MethodParameters has fewer entries + while names.len() < param_count { + names.push(None); + } + return names; + 
} + } + + // Fall back to LocalVariableTable (Code sub-attribute, available with javac -g) + for attr in &method.attributes { + if let Some(AttributeInfoVariant::Code(code)) = &attr.info_parsed { + for sub_attr in &code.attributes { + if let Some(AttributeInfoVariant::LocalVariableTable(lvt)) = &sub_attr.info_parsed { + // Build a slot→name map for parameter slots + // Parameter slots: if instance method, slot 0 = this, params start at 1 + let first_param_slot: u16 = if is_static { 0 } else { 1 }; + let mut slot_to_name = std::collections::HashMap::new(); + for item in &lvt.items { + // Parameters typically have start_pc == 0 + if item.start_pc == 0 + && let Some(name) = class_file.get_utf8(item.name_index) + { + slot_to_name.insert(item.index, name.to_string()); + } + } + + // Walk through params computing expected slots + let mut names = Vec::with_capacity(param_count); + let mut slot = first_param_slot; + for param in &params { + names.push(slot_to_name.get(&slot).cloned()); + slot += if param.is_wide() { 2 } else { 1 }; + } + return names; + } + } + } + } + + vec![None; param_count] +} + +/// Strip trailing return instructions from generated code for prepend mode. +fn strip_trailing_returns(instructions: &mut Vec<Instruction>) { + while let Some(last) = instructions.last() { + match last { + Instruction::Return + | Instruction::Ireturn + | Instruction::Lreturn + | Instruction::Freturn + | Instruction::Dreturn + | Instruction::Areturn => { + instructions.pop(); + } + _ => break, + } + } +} + +/// Extract the offset_delta from a StackMapFrame. +fn frame_offset_delta(frame: &StackMapFrame) -> u16 { + match &frame.inner { + StackMapFrameInner::SameFrame {} => frame.frame_type as u16, + StackMapFrameInner::SameLocals1StackItemFrame { .. } => (frame.frame_type - 64) as u16, + StackMapFrameInner::SameLocals1StackItemFrameExtended { offset_delta, .. } + | StackMapFrameInner::ChopFrame { offset_delta, .. } + | StackMapFrameInner::SameFrameExtended { offset_delta, .. 
} + | StackMapFrameInner::AppendFrame { offset_delta, .. } + | StackMapFrameInner::FullFrame { offset_delta, .. } => *offset_delta, + } +} + +/// Convert StackMapTable frames from delta-encoded to absolute bytecode offsets. +fn frames_to_absolute(smt: &StackMapTableAttribute) -> Vec<(u32, StackMapFrame)> { + let mut result = Vec::new(); + let mut prev_offset: i64 = -1; + for frame in &smt.entries { + let delta = frame_offset_delta(frame) as i64; + let abs_offset = (prev_offset + delta + 1) as u32; + prev_offset = abs_offset as i64; + result.push((abs_offset, frame.clone())); + } + result +} + +/// Re-encode a frame with a new offset_delta, choosing the most compact representation. +fn reencode_frame_with_delta(frame: &StackMapFrame, new_delta: u16) -> StackMapFrame { + match &frame.inner { + StackMapFrameInner::SameFrame {} | StackMapFrameInner::SameFrameExtended { .. } => { + if new_delta <= 63 { + StackMapFrame { + frame_type: new_delta as u8, + inner: StackMapFrameInner::SameFrame {}, + } + } else { + StackMapFrame { + frame_type: 251, + inner: StackMapFrameInner::SameFrameExtended { + offset_delta: new_delta, + }, + } + } + } + StackMapFrameInner::SameLocals1StackItemFrame { stack } + | StackMapFrameInner::SameLocals1StackItemFrameExtended { stack, .. } => { + if new_delta <= 63 { + StackMapFrame { + frame_type: 64 + new_delta as u8, + inner: StackMapFrameInner::SameLocals1StackItemFrame { + stack: stack.clone(), + }, + } + } else { + StackMapFrame { + frame_type: 247, + inner: StackMapFrameInner::SameLocals1StackItemFrameExtended { + offset_delta: new_delta, + stack: stack.clone(), + }, + } + } + } + StackMapFrameInner::ChopFrame { .. } => StackMapFrame { + frame_type: frame.frame_type, + inner: StackMapFrameInner::ChopFrame { + offset_delta: new_delta, + }, + }, + StackMapFrameInner::AppendFrame { locals, .. 
} => StackMapFrame { + frame_type: frame.frame_type, + inner: StackMapFrameInner::AppendFrame { + offset_delta: new_delta, + locals: locals.clone(), + }, + }, + StackMapFrameInner::FullFrame { + number_of_locals, + locals, + number_of_stack_items, + stack, + .. + } => StackMapFrame { + frame_type: 255, + inner: StackMapFrameInner::FullFrame { + offset_delta: new_delta, + number_of_locals: *number_of_locals, + locals: locals.clone(), + number_of_stack_items: *number_of_stack_items, + stack: stack.clone(), + }, + }, + } +} + +/// Take frames with absolute offsets, re-compute deltas, and re-encode. +fn reencode_frames_absolute(frames: &[(u32, StackMapFrame)]) -> Vec<StackMapFrame> { + let mut result = Vec::new(); + let mut prev_offset: i64 = -1; + for (abs_offset, frame) in frames { + let new_delta = (*abs_offset as i64 - prev_offset - 1) as u16; + prev_offset = *abs_offset as i64; + result.push(reencode_frame_with_delta(frame, new_delta)); + } + result +} + +/// Compile Java source and replace a method's body in the class file. +/// +/// When `method_descriptor` is `Some`, the method is matched by both name and +/// descriptor, disambiguating overloaded methods. When `None`, the first method +/// with the given name is used. 
+pub fn compile_method_body_impl( + source: &str, + class_file: &mut ClassFile, + method_name: &str, + method_descriptor: Option<&str>, + options: &CompileOptions, +) -> Result<(), CompileError> { + // Find the method by name (and optionally descriptor) + let method_idx = class_file + .methods + .iter() + .position(|m| { + let name_matches = matches!( + &class_file.const_pool[(m.name_index - 1) as usize], + ConstantInfo::Utf8(u) if u.utf8_string == method_name + ); + if !name_matches { + return false; + } + // If a descriptor is provided, also check it matches + if let Some(desc) = method_descriptor { + matches!( + &class_file.const_pool[(m.descriptor_index - 1) as usize], + ConstantInfo::Utf8(u) if u.utf8_string == desc + ) + } else { + true + } + }) + .ok_or_else(|| CompileError::MethodNotFound { + name: if let Some(desc) = method_descriptor { + format!("{method_name}{desc}") + } else { + method_name.to_string() + }, + })?; + + // Get method info + let is_static = class_file.methods[method_idx] + .access_flags + .contains(MethodAccessFlags::STATIC); + let descriptor_index = class_file.methods[method_idx].descriptor_index; + let method_descriptor = class_file + .get_utf8(descriptor_index) + .ok_or_else(|| CompileError::CodegenError { + message: "could not resolve method descriptor".into(), + })? 
+ .to_string(); + + // Extract parameter names from debug info before mutably borrowing for codegen + let param_names = extract_param_names(class_file, method_idx, is_static, &method_descriptor); + + // Parse source + let lexer = Lexer::new(source); + let tokens = lexer.tokenize()?; + let mut parser = Parser::new(tokens); + let stmts = parser.parse_method_body()?; + + // Generate bytecode + let mut codegen = CodeGenerator::new_with_options( + class_file, + is_static, + &method_descriptor, + options.generate_stack_map_table, + &param_names, + )?; + codegen.generate_body(&stmts)?; + let generated = codegen.finish()?; + + match options.insert_mode { + InsertMode::Replace => replace_method_body(class_file, method_idx, generated, options)?, + InsertMode::Prepend => prepend_to_method_body(class_file, method_idx, generated, options)?, + InsertMode::Append => append_to_method_body(class_file, method_idx, generated, options)?, + } + + class_file.sync_counts(); + Ok(()) +} + +/// Replace the entire method body with newly generated bytecode. 
+fn replace_method_body( + class_file: &mut ClassFile, + method_idx: usize, + generated: super::GeneratedCode, + options: &CompileOptions, +) -> Result<(), CompileError> { + // Build StackMapTable sub-attribute if generated + let smt_sub_attr = if options.generate_stack_map_table { + if let Some(smt) = generated.stack_map_table { + let smt_name_idx = class_file.get_or_add_utf8("StackMapTable"); + let mut smt_attr = AttributeInfo { + attribute_name_index: smt_name_idx, + attribute_length: 0, + info: vec![], + info_parsed: Some(AttributeInfoVariant::StackMapTable(smt)), + }; + smt_attr + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed for StackMapTable failed: {}", e), + })?; + Some(smt_attr) + } else { + None + } + } else { + None + }; + + // Find or create Code attribute + let code_attr_idx = class_file.methods[method_idx] + .attributes + .iter() + .position(|a| matches!(a.info_parsed, Some(AttributeInfoVariant::Code(_)))); + + if let Some(attr_idx) = code_attr_idx { + let code = match &mut class_file.methods[method_idx].attributes[attr_idx].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + + // Replace instructions and update stack/locals + code.code = generated.instructions; + code.max_stack = generated.max_stack; + code.max_locals = generated.max_locals; + code.exception_table = generated.exception_table; + code.exception_table_length = code.exception_table.len() as u16; + + // Strip debug and verification sub-attributes that reference old bytecode offsets + code.attributes.retain(|a| { + !matches!( + a.info_parsed, + Some(AttributeInfoVariant::LineNumberTable(_)) + | Some(AttributeInfoVariant::LocalVariableTable(_)) + | Some(AttributeInfoVariant::LocalVariableTypeTable(_)) + ) + }); + // Always strip old StackMapTable + code.attributes + .retain(|a| !matches!(a.info_parsed, Some(AttributeInfoVariant::StackMapTable(_)))); + // Attach new StackMapTable if generated + if let 
Some(smt_attr) = smt_sub_attr { + code.attributes.push(smt_attr); + } + code.attributes_count = code.attributes.len() as u16; + + // Sync + class_file.methods[method_idx].attributes[attr_idx] + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed failed: {}", e), + })?; + } else { + // Create a new Code attribute + let code_name_idx = class_file.get_or_add_utf8("Code"); + let exception_table_length = generated.exception_table.len() as u16; + + let mut sub_attrs = Vec::new(); + if let Some(smt_attr) = smt_sub_attr { + sub_attrs.push(smt_attr); + } + + let code_attr = CodeAttribute { + max_stack: generated.max_stack, + max_locals: generated.max_locals, + code_length: 0, // will be set by sync + code: generated.instructions, + exception_table_length, + exception_table: generated.exception_table, + attributes_count: sub_attrs.len() as u16, + attributes: sub_attrs, + }; + + let mut attr_info = AttributeInfo { + attribute_name_index: code_name_idx, + attribute_length: 0, + info: vec![], + info_parsed: Some(AttributeInfoVariant::Code(code_attr)), + }; + attr_info + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed failed: {}", e), + })?; + + class_file.methods[method_idx].attributes.push(attr_info); + class_file.methods[method_idx].attributes_count = + class_file.methods[method_idx].attributes.len() as u16; + } + + Ok(()) +} + +/// Prepend newly generated bytecode before the existing method body. 
+fn prepend_to_method_body( + class_file: &mut ClassFile, + method_idx: usize, + mut generated: super::GeneratedCode, + options: &CompileOptions, +) -> Result<(), CompileError> { + // Strip trailing return so prepended code falls through to original + strip_trailing_returns(&mut generated.instructions); + if generated.instructions.is_empty() { + return Ok(()); // Nothing to prepend + } + + // Find existing Code attribute (required for prepend) + let attr_idx = class_file.methods[method_idx] + .attributes + .iter() + .position(|a| matches!(a.info_parsed, Some(AttributeInfoVariant::Code(_)))) + .ok_or_else(|| CompileError::CodegenError { + message: "method has no Code attribute to prepend to".into(), + })?; + + // Pre-resolve StackMapTable name index before taking mutable borrow on code + let smt_name_idx = if options.generate_stack_map_table { + Some(class_file.get_or_add_utf8("StackMapTable")) + } else { + None + }; + + let code = match &mut class_file.methods[method_idx].attributes[attr_idx].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + + // Concatenate: new instructions ++ old instructions + let old_instructions = std::mem::take(&mut code.code); + let new_count = generated.instructions.len(); + let mut combined = generated.instructions; + combined.extend(old_instructions); + + // Compute byte addresses for the combined stream + let addresses = compute_byte_addresses(&combined); + let prepend_byte_size = if new_count < addresses.len() { + addresses[new_count] + } else { + // All instructions are new (shouldn't happen, but be safe) + *addresses.last().unwrap_or(&0) + }; + + // Shift existing exception table entries by prepend_byte_size + for entry in &mut code.exception_table { + entry.start_pc += prepend_byte_size as u16; + entry.end_pc += prepend_byte_size as u16; + entry.handler_pc += prepend_byte_size as u16; + } + + // Merge exception tables: new first, then shifted old + let mut merged_exceptions = 
generated.exception_table; + merged_exceptions.append(&mut code.exception_table); + + // Handle StackMapTable merging + let old_smt = code.attributes.iter().find_map(|a| match &a.info_parsed { + Some(AttributeInfoVariant::StackMapTable(smt)) => Some(smt.clone()), + _ => None, + }); + + // Strip old StackMapTable and debug attributes + code.attributes.retain(|a| { + !matches!( + a.info_parsed, + Some(AttributeInfoVariant::StackMapTable(_)) + | Some(AttributeInfoVariant::LineNumberTable(_)) + | Some(AttributeInfoVariant::LocalVariableTable(_)) + | Some(AttributeInfoVariant::LocalVariableTypeTable(_)) + ) + }); + + // Build merged StackMapTable + if options.generate_stack_map_table { + let mut all_frames: Vec<(u32, StackMapFrame)> = Vec::new(); + + // Add new code's frames (already at correct absolute offsets) + if let Some(new_smt) = &generated.stack_map_table { + all_frames.extend(frames_to_absolute(new_smt)); + } + + // Add shifted old frames + if let Some(old_smt) = &old_smt { + for (offset, frame) in frames_to_absolute(old_smt) { + all_frames.push((offset + prepend_byte_size, frame)); + } + } + + if !all_frames.is_empty() { + all_frames.sort_by_key(|(offset, _)| *offset); + let reencoded = reencode_frames_absolute(&all_frames); + + let smt = StackMapTableAttribute { + number_of_entries: reencoded.len() as u16, + entries: reencoded, + }; + + let mut smt_attr = AttributeInfo { + attribute_name_index: smt_name_idx.unwrap(), + attribute_length: 0, + info: vec![], + info_parsed: Some(AttributeInfoVariant::StackMapTable(smt)), + }; + smt_attr + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed for StackMapTable failed: {}", e), + })?; + code.attributes.push(smt_attr); + } + } + + // Update CodeAttribute + code.code = combined; + code.max_stack = std::cmp::max(generated.max_stack, code.max_stack); + code.max_locals = std::cmp::max(generated.max_locals, code.max_locals); + code.exception_table = merged_exceptions; + 
code.exception_table_length = code.exception_table.len() as u16; + code.attributes_count = code.attributes.len() as u16; + + // Sync + class_file.methods[method_idx].attributes[attr_idx] + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed failed: {}", e), + })?; + + Ok(()) +} + +/// Append newly generated bytecode after the existing method body. +/// +/// The trailing return instruction(s) of the original method are stripped +/// so execution falls through to the appended code. The appended code +/// must contain its own return instruction. +fn append_to_method_body( + class_file: &mut ClassFile, + method_idx: usize, + generated: super::GeneratedCode, + options: &CompileOptions, +) -> Result<(), CompileError> { + if generated.instructions.is_empty() { + return Ok(()); // Nothing to append + } + + // Find existing Code attribute (required for append) + let attr_idx = class_file.methods[method_idx] + .attributes + .iter() + .position(|a| matches!(a.info_parsed, Some(AttributeInfoVariant::Code(_)))) + .ok_or_else(|| CompileError::CodegenError { + message: "method has no Code attribute to append to".into(), + })?; + + // Pre-resolve StackMapTable name index before taking mutable borrow on code + let smt_name_idx = if options.generate_stack_map_table { + Some(class_file.get_or_add_utf8("StackMapTable")) + } else { + None + }; + + let code = match &mut class_file.methods[method_idx].attributes[attr_idx].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + + // Strip trailing returns from the OLD instructions so they fall through + // to the appended code. Non-trailing returns (e.g. early returns in branches) + // are left intact — those paths will skip the appended code. 
+ strip_trailing_returns(&mut code.code); + + // Concatenate: old instructions ++ new instructions + let old_count = code.code.len(); + let mut combined = std::mem::take(&mut code.code); + combined.extend(generated.instructions); + + // Compute byte addresses for the combined stream + let addresses = compute_byte_addresses(&combined); + let old_byte_size = if old_count < addresses.len() { + addresses[old_count] + } else { + *addresses.last().unwrap_or(&0) + }; + + // Shift new exception table entries by old_byte_size + let mut new_exceptions = generated.exception_table; + for entry in &mut new_exceptions { + entry.start_pc += old_byte_size as u16; + entry.end_pc += old_byte_size as u16; + entry.handler_pc += old_byte_size as u16; + } + + // Merge exception tables: old first, then shifted new + let mut merged_exceptions = std::mem::take(&mut code.exception_table); + merged_exceptions.append(&mut new_exceptions); + + // Handle StackMapTable merging + let old_smt = code.attributes.iter().find_map(|a| match &a.info_parsed { + Some(AttributeInfoVariant::StackMapTable(smt)) => Some(smt.clone()), + _ => None, + }); + + // Strip old StackMapTable and debug attributes + code.attributes.retain(|a| { + !matches!( + a.info_parsed, + Some(AttributeInfoVariant::StackMapTable(_)) + | Some(AttributeInfoVariant::LineNumberTable(_)) + | Some(AttributeInfoVariant::LocalVariableTable(_)) + | Some(AttributeInfoVariant::LocalVariableTypeTable(_)) + ) + }); + + // Build merged StackMapTable + if options.generate_stack_map_table { + let mut all_frames: Vec<(u32, StackMapFrame)> = Vec::new(); + + // Old frames stay at their original absolute offsets + if let Some(old_smt) = &old_smt { + all_frames.extend(frames_to_absolute(old_smt)); + } + + // New frames shifted by old_byte_size + if let Some(new_smt) = &generated.stack_map_table { + for (offset, frame) in frames_to_absolute(new_smt) { + all_frames.push((offset + old_byte_size, frame)); + } + } + + if !all_frames.is_empty() { + 
all_frames.sort_by_key(|(offset, _)| *offset); + let reencoded = reencode_frames_absolute(&all_frames); + + let smt = StackMapTableAttribute { + number_of_entries: reencoded.len() as u16, + entries: reencoded, + }; + + let mut smt_attr = AttributeInfo { + attribute_name_index: smt_name_idx.unwrap(), + attribute_length: 0, + info: vec![], + info_parsed: Some(AttributeInfoVariant::StackMapTable(smt)), + }; + smt_attr + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed for StackMapTable failed: {}", e), + })?; + code.attributes.push(smt_attr); + } + } + + // Update CodeAttribute + code.code = combined; + code.max_stack = std::cmp::max(generated.max_stack, code.max_stack); + code.max_locals = std::cmp::max(generated.max_locals, code.max_locals); + code.exception_table = merged_exceptions; + code.exception_table_length = code.exception_table.len() as u16; + code.attributes_count = code.attributes.len() as u16; + + // Sync + class_file.methods[method_idx].attributes[attr_idx] + .sync_from_parsed() + .map_err(|e| CompileError::CodegenError { + message: format!("sync_from_parsed failed: {}", e), + })?; + + Ok(()) +} diff --git a/src/compile/stack_calc.rs b/src/compile/stack_calc.rs new file mode 100644 index 0000000..e679e6e --- /dev/null +++ b/src/compile/stack_calc.rs @@ -0,0 +1,283 @@ +use crate::code_attribute::Instruction; + +/// Compute max_stack by walking instructions and tracking stack depth. +pub fn compute_max_stack(instructions: &[Instruction]) -> u16 { + let mut depth: i32 = 0; + let mut max_depth: i32 = 0; + + for instr in instructions { + depth += stack_delta(instr); + if depth > max_depth { + max_depth = depth; + } + // Clamp to prevent underflow from unreachable code + if depth < 0 { + depth = 0; + } + } + + // Safety margin of +2: the linear walk doesn't model control flow, so it can + // underestimate the stack at branch merge points. 
For example, an exception handler + // pushes one value (the exception) that the linear walk doesn't see, and certain + // patterns (dup + method call) can temporarily exceed the tracked depth. The +2 + // covers these cases conservatively without requiring a full CFG analysis. + let result = max_depth + 2; + result.max(1) as u16 +} + +/// Returns the net stack depth change for an instruction. +fn stack_delta(instr: &Instruction) -> i32 { + match instr { + // Constants: push 1 + Instruction::Aconstnull + | Instruction::Iconstm1 + | Instruction::Iconst0 + | Instruction::Iconst1 + | Instruction::Iconst2 + | Instruction::Iconst3 + | Instruction::Iconst4 + | Instruction::Iconst5 + | Instruction::Fconst0 + | Instruction::Fconst1 + | Instruction::Fconst2 + | Instruction::Bipush(_) + | Instruction::Sipush(_) + | Instruction::Ldc(_) + | Instruction::LdcW(_) => 1, + + // Long/Double constants: push 2 (but treated as 1 category-2 value conceptually) + // JVM spec: long/double use 2 stack slots + Instruction::Lconst0 + | Instruction::Lconst1 + | Instruction::Dconst0 + | Instruction::Dconst1 => 2, + Instruction::Ldc2W(_) => 2, + + // Loads: push 1 (or 2 for long/double) + Instruction::Iload(_) + | Instruction::Iload0 + | Instruction::Iload1 + | Instruction::Iload2 + | Instruction::Iload3 + | Instruction::Fload(_) + | Instruction::Fload0 + | Instruction::Fload1 + | Instruction::Fload2 + | Instruction::Fload3 + | Instruction::Aload(_) + | Instruction::Aload0 + | Instruction::Aload1 + | Instruction::Aload2 + | Instruction::Aload3 + | Instruction::IloadWide(_) + | Instruction::FloadWide(_) + | Instruction::AloadWide(_) => 1, + + Instruction::Lload(_) + | Instruction::Lload0 + | Instruction::Lload1 + | Instruction::Lload2 + | Instruction::Lload3 + | Instruction::Dload(_) + | Instruction::Dload0 + | Instruction::Dload1 + | Instruction::Dload2 + | Instruction::Dload3 + | Instruction::LloadWide(_) + | Instruction::DloadWide(_) => 2, + + // Array loads: pop 2 (arrayref + index), push 1 
(or 2) + Instruction::Iaload + | Instruction::Faload + | Instruction::Aaload + | Instruction::Baload + | Instruction::Caload + | Instruction::Saload => -1, // -2 + 1 + + Instruction::Laload | Instruction::Daload => 0, // -2 + 2 + + // Stores: pop 1 (or 2 for long/double) + Instruction::Istore(_) + | Instruction::Istore0 + | Instruction::Istore1 + | Instruction::Istore2 + | Instruction::Istore3 + | Instruction::Fstore(_) + | Instruction::Fstore0 + | Instruction::Fstore1 + | Instruction::Fstore2 + | Instruction::Fstore3 + | Instruction::Astore(_) + | Instruction::Astore0 + | Instruction::Astore1 + | Instruction::Astore2 + | Instruction::Astore3 + | Instruction::IstoreWide(_) + | Instruction::FstoreWide(_) + | Instruction::AstoreWide(_) => -1, + + Instruction::Lstore(_) + | Instruction::Lstore0 + | Instruction::Lstore1 + | Instruction::Lstore2 + | Instruction::Lstore3 + | Instruction::Dstore(_) + | Instruction::Dstore0 + | Instruction::Dstore1 + | Instruction::Dstore2 + | Instruction::Dstore3 + | Instruction::LstoreWide(_) + | Instruction::DstoreWide(_) => -2, + + // Array stores: pop 3 (arrayref + index + value) + Instruction::Iastore + | Instruction::Fastore + | Instruction::Aastore + | Instruction::Bastore + | Instruction::Castore + | Instruction::Sastore => -3, + + Instruction::Lastore | Instruction::Dastore => -4, // pop arrayref + index + long/double + + // Stack manipulation + Instruction::Pop => -1, + Instruction::Pop2 => -2, + Instruction::Dup => 1, + Instruction::Dupx1 => 1, + Instruction::Dupx2 => 1, + Instruction::Dup2 => 2, + Instruction::Dup2x1 => 2, + Instruction::Dup2x2 => 2, + Instruction::Swap => 0, + + // Arithmetic: pop 2, push 1 (net -1 for int/float) + Instruction::Iadd + | Instruction::Isub + | Instruction::Imul + | Instruction::Idiv + | Instruction::Irem + | Instruction::Ishl + | Instruction::Ishr + | Instruction::Iushr + | Instruction::Iand + | Instruction::Ior + | Instruction::Ixor + | Instruction::Fadd + | Instruction::Fsub + | 
Instruction::Fmul + | Instruction::Fdiv + | Instruction::Frem => -1, + + // Long/double arithmetic: pop 4, push 2 (net -2) + Instruction::Ladd + | Instruction::Lsub + | Instruction::Lmul + | Instruction::Ldiv + | Instruction::Lrem + | Instruction::Land + | Instruction::Lor + | Instruction::Lxor + | Instruction::Dadd + | Instruction::Dsub + | Instruction::Dmul + | Instruction::Ddiv + | Instruction::Drem => -2, + + // Long shift: pop long(2) + int(1), push long(2) = -1 + Instruction::Lshl | Instruction::Lshr | Instruction::Lushr => -1, + + // Negate: pop 1, push 1 = 0 + Instruction::Ineg | Instruction::Fneg => 0, + Instruction::Lneg | Instruction::Dneg => 0, + + // Iinc doesn't touch the stack + Instruction::Iinc { .. } | Instruction::IincWide { .. } => 0, + + // Conversions: same stack effect as source and target sizes + Instruction::I2l | Instruction::I2d | Instruction::F2l | Instruction::F2d => 1, // push extra slot + Instruction::L2i | Instruction::L2f | Instruction::D2i | Instruction::D2f => -1, // lose a slot + Instruction::I2f + | Instruction::I2b + | Instruction::I2c + | Instruction::I2s + | Instruction::F2i => 0, + Instruction::L2d | Instruction::D2l => 0, // 2 -> 2 + + // Comparisons + Instruction::Lcmp => -3, // pop 2 longs (4 slots), push int (1) = -3 + Instruction::Fcmpl | Instruction::Fcmpg => -1, // pop 2, push 1 + Instruction::Dcmpl | Instruction::Dcmpg => -3, // pop 2 doubles (4 slots), push int + + // Branches: pop operand(s), no push + Instruction::Ifeq(_) + | Instruction::Ifne(_) + | Instruction::Iflt(_) + | Instruction::Ifge(_) + | Instruction::Ifgt(_) + | Instruction::Ifle(_) + | Instruction::Ifnull(_) + | Instruction::Ifnonnull(_) => -1, + + Instruction::IfIcmpeq(_) + | Instruction::IfIcmpne(_) + | Instruction::IfIcmplt(_) + | Instruction::IfIcmpge(_) + | Instruction::IfIcmpgt(_) + | Instruction::IfIcmple(_) + | Instruction::IfAcmpeq(_) + | Instruction::IfAcmpne(_) => -2, + + Instruction::Goto(_) | Instruction::GotoW(_) => 0, + + // Returns + 
Instruction::Return => 0, + Instruction::Ireturn | Instruction::Freturn | Instruction::Areturn => -1, + Instruction::Lreturn | Instruction::Dreturn => -2, + + // Field access + Instruction::Getstatic(_) => 1, // push value + Instruction::Putstatic(_) => -1, // pop value + Instruction::Getfield(_) => 0, // pop objectref, push value + Instruction::Putfield(_) => -2, // pop objectref + value + + // Method invocations: complex, approximate conservatively + // For MVP, assume methods consume args and push at most 1 + Instruction::Invokevirtual(_) | Instruction::Invokespecial(_) => { + // Pops objectref + args, pushes return. Approximate: -1 (net for void methods with objectref) + // This is approximate; the actual delta depends on the method descriptor. + -1 + } + Instruction::Invokestatic(_) => { + // Pops args, pushes return. Approximate: 0 + 0 + } + Instruction::Invokeinterface { .. } => -1, + Instruction::Invokedynamic { .. } => 0, + + // Object creation + Instruction::New(_) => 1, + Instruction::Newarray(_) => 0, // pop count, push arrayref + Instruction::Anewarray(_) => 0, + Instruction::Arraylength => 0, // pop arrayref, push length + + Instruction::Athrow => -1, + Instruction::Checkcast(_) => 0, + Instruction::Instanceof(_) => 0, // pop ref, push int + + Instruction::Monitorenter | Instruction::Monitorexit => -1, + + Instruction::Multianewarray { dimensions, .. } => { + 1 - (*dimensions as i32) // pop N counts, push arrayref + } + + // Switch + Instruction::Tableswitch { .. } | Instruction::Lookupswitch { .. 
} => -1, + + // JSR/RET (legacy) + Instruction::Jsr(_) | Instruction::JsrW(_) => 1, + Instruction::Ret(_) | Instruction::RetWide(_) => 0, + + // Nop + Instruction::Nop => 0, + } +} diff --git a/src/compile/stackmap.rs b/src/compile/stackmap.rs new file mode 100644 index 0000000..e876ce8 --- /dev/null +++ b/src/compile/stackmap.rs @@ -0,0 +1,213 @@ +use crate::attribute_info::{ + StackMapFrame, StackMapFrameInner, StackMapTableAttribute, VerificationTypeInfo, +}; + +/// Verification type used during codegen-assisted frame tracking. +#[derive(Clone, Debug, PartialEq)] +pub enum VType { + Top, + Integer, + Float, + Long, + Double, + Null, + UninitializedThis, + Object(u16), // constant pool index for the class +} + +impl VType { + fn to_verification_type_info(&self) -> VerificationTypeInfo { + match self { + VType::Top => VerificationTypeInfo::Top, + VType::Integer => VerificationTypeInfo::Integer, + VType::Float => VerificationTypeInfo::Float, + VType::Long => VerificationTypeInfo::Long, + VType::Double => VerificationTypeInfo::Double, + VType::Null => VerificationTypeInfo::Null, + VType::UninitializedThis => VerificationTypeInfo::UninitializedThis, + VType::Object(idx) => VerificationTypeInfo::Object { class: *idx }, + } + } +} + +/// A snapshot of the frame state at a particular bytecode offset. +#[derive(Clone, Debug)] +pub struct FrameSnapshot { + pub bytecode_offset: u32, + pub locals: Vec<VType>, + pub stack: Vec<VType>, +} + +/// Tracks type state during code generation for StackMapTable building. +pub struct FrameTracker { + /// Initial locals (from method parameters). + initial_locals: Vec<VType>, + /// Recorded frame snapshots at branch targets / exception handlers. + snapshots: Vec<FrameSnapshot>, +} + +impl FrameTracker { + pub fn new(initial_locals: Vec<VType>) -> Self { + FrameTracker { + initial_locals, + snapshots: Vec::new(), + } + } + + /// Record a frame snapshot at the given bytecode offset. 
+ pub fn record_frame(&mut self, offset: u32, locals: Vec<VType>, stack: Vec<VType>) { + // If a frame already exists at this offset, replace it — the last binding + // at a given offset represents the most accurate state for subsequent code. + if let Some(existing) = self + .snapshots + .iter_mut() + .find(|s| s.bytecode_offset == offset) + { + existing.locals = locals; + existing.stack = stack; + return; + } + self.snapshots.push(FrameSnapshot { + bytecode_offset: offset, + locals, + stack, + }); + } + + /// Build the StackMapTableAttribute from recorded snapshots. + pub fn build(mut self) -> Option<StackMapTableAttribute> { + if self.snapshots.is_empty() { + return None; + } + + self.snapshots.sort_by_key(|s| s.bytecode_offset); + self.snapshots.dedup_by_key(|s| s.bytecode_offset); + + let mut entries = Vec::new(); + let mut prev_offset: i64 = -1; + let mut prev_locals = self.initial_locals.clone(); + + for snapshot in &self.snapshots { + let offset_delta = (snapshot.bytecode_offset as i64 - prev_offset - 1) as u16; + prev_offset = snapshot.bytecode_offset as i64; + + let frame = encode_frame( + &prev_locals, + &snapshot.locals, + &snapshot.stack, + offset_delta, + ); + entries.push(frame); + prev_locals = snapshot.locals.clone(); + } + + Some(StackMapTableAttribute { + number_of_entries: entries.len() as u16, + entries, + }) + } +} + +/// Choose the most compact frame encoding. +/// `prev_locals` is the locals from the previous frame (or initial implicit frame). 
+fn encode_frame( + prev_locals: &[VType], + locals: &[VType], + stack: &[VType], + offset_delta: u16, +) -> StackMapFrame { + // SameFrame: same locals, empty stack + if stack.is_empty() && locals_match(prev_locals, locals) { + if offset_delta <= 63 { + return StackMapFrame { + frame_type: offset_delta as u8, + inner: StackMapFrameInner::SameFrame {}, + }; + } else { + return StackMapFrame { + frame_type: 251, + inner: StackMapFrameInner::SameFrameExtended { offset_delta }, + }; + } + } + + // SameLocals1StackItem: same locals, exactly 1 stack item + if stack.len() == 1 && locals_match(prev_locals, locals) { + let stack_item = stack[0].to_verification_type_info(); + if offset_delta <= 63 { + return StackMapFrame { + frame_type: 64 + offset_delta as u8, + inner: StackMapFrameInner::SameLocals1StackItemFrame { stack: stack_item }, + }; + } else { + return StackMapFrame { + frame_type: 247, + inner: StackMapFrameInner::SameLocals1StackItemFrameExtended { + offset_delta, + stack: stack_item, + }, + }; + } + } + + // AppendFrame: 1-3 new locals added, empty stack, prefix matches previous + if stack.is_empty() && locals.len() > prev_locals.len() { + let extra = locals.len() - prev_locals.len(); + if (1..=3).contains(&extra) + && locals.len() >= prev_locals.len() + && locals[..prev_locals.len()] + .iter() + .zip(prev_locals.iter()) + .all(|(a, b)| a == b) + { + let new_locals: Vec<VerificationTypeInfo> = locals[prev_locals.len()..] 
+ .iter() + .map(|v| v.to_verification_type_info()) + .collect(); + return StackMapFrame { + frame_type: 251 + extra as u8, + inner: StackMapFrameInner::AppendFrame { + offset_delta, + locals: new_locals, + }, + }; + } + } + + // ChopFrame: 1-3 locals removed, empty stack, prefix matches previous + if stack.is_empty() && locals.len() < prev_locals.len() { + let chopped = prev_locals.len() - locals.len(); + if (1..=3).contains(&chopped) && locals.iter().zip(prev_locals.iter()).all(|(a, b)| a == b) + { + return StackMapFrame { + frame_type: (251 - chopped) as u8, + inner: StackMapFrameInner::ChopFrame { offset_delta }, + }; + } + } + + // FullFrame: complete specification + let local_vtypes: Vec<VerificationTypeInfo> = locals + .iter() + .map(|v| v.to_verification_type_info()) + .collect(); + let stack_vtypes: Vec<VerificationTypeInfo> = stack + .iter() + .map(|v| v.to_verification_type_info()) + .collect(); + + StackMapFrame { + frame_type: 255, + inner: StackMapFrameInner::FullFrame { + offset_delta, + number_of_locals: local_vtypes.len() as u16, + locals: local_vtypes, + number_of_stack_items: stack_vtypes.len() as u16, + stack: stack_vtypes, + }, + } +} + +fn locals_match(reference: &[VType], current: &[VType]) -> bool { + reference.len() == current.len() && reference.iter().zip(current.iter()).all(|(a, b)| a == b) +} diff --git a/src/constant_info/mod.rs b/src/constant_info/mod.rs index 4aebee5..6108ca4 100644 --- a/src/constant_info/mod.rs +++ b/src/constant_info/mod.rs @@ -1,5 +1,3 @@ -mod parser; mod types; -pub use self::parser::constant_parser; pub use self::types::*; diff --git a/src/constant_info/parser.rs b/src/constant_info/parser.rs deleted file mode 100644 index 7cd196b..0000000 --- a/src/constant_info/parser.rs +++ /dev/null @@ -1,187 +0,0 @@ -use crate::constant_info::*; -use nom::{ - Err, - bytes::complete::take, - combinator::map, - error::{Error, ErrorKind}, - number::complete::{be_f32, be_f64, be_i32, be_i64, be_u8, be_u16}, -}; - -fn utf8_constant(input: &[u8]) -> Utf8Constant { - let 
utf8_string = - cesu8::from_java_cesu8(input).unwrap_or_else(|_| String::from_utf8_lossy(input)); - Utf8Constant { - utf8_string: utf8_string.to_string().into(), - } -} - -fn const_utf8(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, length) = be_u16(input)?; - let (input, constant) = map(take(length), utf8_constant)(input)?; - Ok((input, ConstantInfo::Utf8(constant))) -} - -fn const_integer(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, value) = be_i32(input)?; - Ok((input, ConstantInfo::Integer(IntegerConstant { value }))) -} - -fn const_float(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, value) = be_f32(input)?; - Ok((input, ConstantInfo::Float(FloatConstant { value }))) -} - -fn const_long(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, value) = be_i64(input)?; - Ok((input, ConstantInfo::Long(LongConstant { value }))) -} - -fn const_double(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, value) = be_f64(input)?; - Ok((input, ConstantInfo::Double(DoubleConstant { value }))) -} - -fn const_class(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, name_index) = be_u16(input)?; - Ok((input, ConstantInfo::Class(ClassConstant { name_index }))) -} - -fn const_string(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, string_index) = be_u16(input)?; - Ok((input, ConstantInfo::String(StringConstant { string_index }))) -} - -fn const_field_ref(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, class_index) = be_u16(input)?; - let (input, name_and_type_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::FieldRef(FieldRefConstant { - class_index, - name_and_type_index, - }), - )) -} - -fn const_method_ref(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, class_index) = be_u16(input)?; - let (input, name_and_type_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::MethodRef(MethodRefConstant { - class_index, - name_and_type_index, - }), - )) -} - -fn const_interface_method_ref(input: &[u8]) -> 
ConstantInfoResult<'_> { - let (input, class_index) = be_u16(input)?; - let (input, name_and_type_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::InterfaceMethodRef(InterfaceMethodRefConstant { - class_index, - name_and_type_index, - }), - )) -} - -fn const_name_and_type(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, name_index) = be_u16(input)?; - let (input, descriptor_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::NameAndType(NameAndTypeConstant { - name_index, - descriptor_index, - }), - )) -} - -fn const_method_handle(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, reference_kind) = be_u8(input)?; - let (input, reference_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::MethodHandle(MethodHandleConstant { - reference_kind, - reference_index, - }), - )) -} - -fn const_method_type(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, descriptor_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::MethodType(MethodTypeConstant { descriptor_index }), - )) -} - -fn const_invoke_dynamic(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, bootstrap_method_attr_index) = be_u16(input)?; - let (input, name_and_type_index) = be_u16(input)?; - Ok(( - input, - ConstantInfo::InvokeDynamic(InvokeDynamicConstant { - bootstrap_method_attr_index, - name_and_type_index, - }), - )) -} - -type ConstantInfoResult<'a> = Result<(&'a [u8], ConstantInfo), Err<Error<&'a [u8]>>>; -type ConstantInfoVecResult<'a> = Result<(&'a [u8], Vec<ConstantInfo>), Err<Error<&'a [u8]>>>; - -fn const_block_parser(input: &[u8], const_type: u8) -> ConstantInfoResult<'_> { - match const_type { - 1 => const_utf8(input), - 3 => const_integer(input), - 4 => const_float(input), - 5 => const_long(input), - 6 => const_double(input), - 7 => const_class(input), - 8 => const_string(input), - 9 => const_field_ref(input), - 10 => const_method_ref(input), - 11 => const_interface_method_ref(input), - 12 => const_name_and_type(input), - 15 => const_method_handle(input), - 16 => const_method_type(input), - 18 => 
const_invoke_dynamic(input), - _ => Result::Err(Err::Error(error_position!(input, ErrorKind::Alt))), - } -} - -fn single_constant_parser(input: &[u8]) -> ConstantInfoResult<'_> { - let (input, const_type) = be_u8(input)?; - let (input, const_block) = const_block_parser(input, const_type)?; - Ok((input, const_block)) -} - -pub fn constant_parser(i: &[u8], const_pool_size: usize) -> ConstantInfoVecResult<'_> { - let mut index = 0; - let mut input = i; - let mut res = Vec::with_capacity(const_pool_size); - while index < const_pool_size { - match single_constant_parser(input) { - Ok((i, o)) => { - // Long and Double Entries have twice the size - // see https://docs.oracle.com/javase/specs/jvms/se6/html/ClassFile.doc.html#1348 - let uses_two_entries = - matches!(o, ConstantInfo::Long(..) | ConstantInfo::Double(..)); - - res.push(o); - if uses_two_entries { - res.push(ConstantInfo::Unusable); - index += 1; - } - input = i; - index += 1; - } - _ => return Result::Err(Err::Error(error_position!(input, ErrorKind::Alt))), - } - } - Ok((input, res)) -} diff --git a/src/constant_info/types.rs b/src/constant_info/types.rs index f66612b..410d5b6 100644 --- a/src/constant_info/types.rs +++ b/src/constant_info/types.rs @@ -1,29 +1,70 @@ -use binrw::{NullWideString, binrw}; +use binrw::{BinResult, binrw}; +use std::fmt::Debug; #[derive(Clone, Debug)] #[binrw] pub enum ConstantInfo { + #[brw(magic(1u8))] Utf8(Utf8Constant), + #[brw(magic(3u8))] Integer(IntegerConstant), + #[brw(magic(4u8))] Float(FloatConstant), + #[brw(magic(5u8))] Long(LongConstant), + #[brw(magic(6u8))] Double(DoubleConstant), + #[brw(magic(7u8))] Class(ClassConstant), + #[brw(magic(8u8))] String(StringConstant), + #[brw(magic(9u8))] FieldRef(FieldRefConstant), + #[brw(magic(10u8))] MethodRef(MethodRefConstant), + #[brw(magic(11u8))] InterfaceMethodRef(InterfaceMethodRefConstant), + #[brw(magic(12u8))] NameAndType(NameAndTypeConstant), + #[brw(magic(15u8))] MethodHandle(MethodHandleConstant), + 
#[brw(magic(16u8))] MethodType(MethodTypeConstant), + #[brw(magic(18u8))] InvokeDynamic(InvokeDynamicConstant), + #[brw(magic(19u8))] + Module(ModuleConstant), + #[brw(magic(20u8))] + Package(PackageConstant), Unusable, } +#[binrw::parser(reader)] +pub fn string_reader() -> BinResult<String> { + let mut buf = [0u8; 2]; + reader.read_exact(&mut buf)?; + let len = u16::from_be_bytes(buf); + let mut string_bytes = vec![0; len as usize]; + reader.read_exact(&mut string_bytes)?; + let utf8_string = cesu8::from_java_cesu8(&string_bytes) + .unwrap_or_else(|_| String::from_utf8_lossy(&string_bytes)); + Ok(utf8_string.to_string()) +} + +#[binrw::writer(writer)] +pub fn string_writer<'a>(s: &'a String) -> BinResult<()> { + let cesu8_bytes = cesu8::to_java_cesu8(s); + writer.write_all(&u16::to_be_bytes(cesu8_bytes.len() as u16))?; + writer.write_all(&cesu8_bytes)?; + Ok(()) +} + #[derive(Clone, Debug)] #[binrw] pub struct Utf8Constant { - pub utf8_string: NullWideString, + #[br(parse_with = crate::constant_info::string_reader)] + #[bw(write_with = crate::constant_info::string_writer)] + pub utf8_string: String, // pub bytes: Vec<u8>, } @@ -110,3 +151,15 @@ pub struct InvokeDynamicConstant { pub bootstrap_method_attr_index: u16, pub name_and_type_index: u16, } + +#[derive(Clone, Debug)] +#[binrw] +pub struct ModuleConstant { + pub name_index: u16, +} + +#[derive(Clone, Debug)] +#[binrw] +pub struct PackageConstant { + pub name_index: u16, +} diff --git a/src/decompile/cfg.rs b/src/decompile/cfg.rs new file mode 100644 index 0000000..7dc1b45 --- /dev/null +++ b/src/decompile/cfg.rs @@ -0,0 +1,302 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use crate::attribute_info::CodeAttribute; +use crate::code_attribute::Instruction; + +use super::cfg_types::*; +use super::util::{compute_addresses, instruction_byte_size}; + +/// Build a control flow graph from a CodeAttribute. 
+pub fn build_cfg(code_attr: &CodeAttribute) -> ControlFlowGraph { + let addressed = compute_addresses(&code_attr.code); + if addressed.is_empty() { + return ControlFlowGraph { + blocks: BTreeMap::new(), + entry: 0, + exception_edges: Vec::new(), + }; + } + + // Step 1: Identify block leaders + let mut leaders = BTreeSet::new(); + leaders.insert(0u32); + + for ex in &code_attr.exception_table { + leaders.insert(ex.handler_pc as u32); + } + + for &(addr, instr) in &addressed { + let next_addr = addr + instruction_byte_size(instr, addr); + match instr { + Instruction::Goto(offset) => { + leaders.insert((addr as i64 + *offset as i64) as u32); + leaders.insert(next_addr); + } + Instruction::GotoW(offset) => { + leaders.insert((addr as i64 + *offset as i64) as u32); + leaders.insert(next_addr); + } + Instruction::Ifeq(off) + | Instruction::Ifne(off) + | Instruction::Iflt(off) + | Instruction::Ifge(off) + | Instruction::Ifgt(off) + | Instruction::Ifle(off) + | Instruction::IfIcmpeq(off) + | Instruction::IfIcmpne(off) + | Instruction::IfIcmplt(off) + | Instruction::IfIcmpge(off) + | Instruction::IfIcmpgt(off) + | Instruction::IfIcmple(off) + | Instruction::IfAcmpeq(off) + | Instruction::IfAcmpne(off) + | Instruction::Ifnull(off) + | Instruction::Ifnonnull(off) => { + leaders.insert((addr as i64 + *off as i64) as u32); + leaders.insert(next_addr); + } + Instruction::Tableswitch { + default, offsets, .. + } => { + leaders.insert((addr as i64 + *default as i64) as u32); + for off in offsets { + leaders.insert((addr as i64 + *off as i64) as u32); + } + leaders.insert(next_addr); + } + Instruction::Lookupswitch { default, pairs, .. 
} => { + leaders.insert((addr as i64 + *default as i64) as u32); + for (_, off) in pairs { + leaders.insert((addr as i64 + *off as i64) as u32); + } + leaders.insert(next_addr); + } + Instruction::Return + | Instruction::Ireturn + | Instruction::Lreturn + | Instruction::Freturn + | Instruction::Dreturn + | Instruction::Areturn + | Instruction::Athrow => { + leaders.insert(next_addr); + } + Instruction::Jsr(off) => { + leaders.insert((addr as i64 + *off as i64) as u32); + leaders.insert(next_addr); + } + Instruction::JsrW(off) => { + leaders.insert((addr as i64 + *off as i64) as u32); + leaders.insert(next_addr); + } + Instruction::Ret(_) | Instruction::RetWide(_) => { + leaders.insert(next_addr); + } + _ => {} + } + } + + // Step 2: Build basic blocks + let leader_vec: Vec<u32> = leaders.iter().copied().collect(); + let mut blocks = BTreeMap::new(); + let addr_to_idx: BTreeMap<u32, usize> = addressed + .iter() + .enumerate() + .map(|(i, (a, _))| (*a, i)) + .collect(); + + for (li, &leader_addr) in leader_vec.iter().enumerate() { + if !addr_to_idx.contains_key(&leader_addr) { + continue; + } + + let start_idx = addr_to_idx[&leader_addr]; + let end_idx = if li + 1 < leader_vec.len() { + addr_to_idx + .get(&leader_vec[li + 1]) + .copied() + .unwrap_or(addressed.len()) + } else { + addressed.len() + }; + + if start_idx >= end_idx { + continue; + } + + let block_instrs: Vec<AddressedInstruction> = addressed[start_idx..end_idx] + .iter() + .map(|(a, i)| AddressedInstruction { + address: *a, + instruction: (*i).clone(), + }) + .collect(); + + let last = block_instrs.last().unwrap(); + let last_addr = last.address; + let last_next = last_addr + instruction_byte_size(&last.instruction, last_addr); + + let terminator = build_terminator(&last.instruction, last_addr, last_next); + + blocks.insert( + leader_addr, + BasicBlock { + id: leader_addr, + instructions: block_instrs, + terminator, + }, + ); + } + + // Step 3: Exception edges + let exception_edges: Vec<ExceptionEdge> = code_attr + .exception_table + .iter() + .map(|e| 
ExceptionEdge { + start_pc: e.start_pc, + end_pc: e.end_pc, + handler_block: e.handler_pc as u32, + catch_type: e.catch_type, + }) + .collect(); + + ControlFlowGraph { + blocks, + entry: 0, + exception_edges, + } +} + +fn build_terminator(instr: &Instruction, addr: u32, next: u32) -> Terminator { + match instr { + Instruction::Goto(off) => Terminator::Goto { + target: (addr as i64 + *off as i64) as u32, + }, + Instruction::GotoW(off) => Terminator::Goto { + target: (addr as i64 + *off as i64) as u32, + }, + Instruction::Ifeq(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntZero(CompareOp::Eq), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Ifne(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntZero(CompareOp::Ne), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Iflt(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntZero(CompareOp::Lt), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Ifge(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntZero(CompareOp::Ge), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Ifgt(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntZero(CompareOp::Gt), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Ifle(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntZero(CompareOp::Le), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfIcmpeq(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntCompare(CompareOp::Eq), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfIcmpne(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntCompare(CompareOp::Ne), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + 
Instruction::IfIcmplt(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntCompare(CompareOp::Lt), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfIcmpge(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntCompare(CompareOp::Ge), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfIcmpgt(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntCompare(CompareOp::Gt), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfIcmple(off) => Terminator::ConditionalBranch { + condition: BranchCondition::IntCompare(CompareOp::Le), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfAcmpeq(off) => Terminator::ConditionalBranch { + condition: BranchCondition::RefCompare(CompareOp::Eq), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::IfAcmpne(off) => Terminator::ConditionalBranch { + condition: BranchCondition::RefCompare(CompareOp::Ne), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Ifnull(off) => Terminator::ConditionalBranch { + condition: BranchCondition::RefNull(true), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Ifnonnull(off) => Terminator::ConditionalBranch { + condition: BranchCondition::RefNull(false), + if_true: (addr as i64 + *off as i64) as u32, + if_false: next, + }, + Instruction::Tableswitch { + default, + low, + high, + offsets, + } => { + let targets: Vec<u32> = offsets + .iter() + .map(|off| (addr as i64 + *off as i64) as u32) + .collect(); + Terminator::TableSwitch { + default: (addr as i64 + *default as i64) as u32, + low: *low, + high: *high, + targets, + } + } + Instruction::Lookupswitch { default, pairs, .. 
} => { + let abs_pairs: Vec<(i32, u32)> = pairs + .iter() + .map(|(key, off)| (*key, (addr as i64 + *off as i64) as u32)) + .collect(); + Terminator::LookupSwitch { + default: (addr as i64 + *default as i64) as u32, + pairs: abs_pairs, + } + } + Instruction::Return + | Instruction::Ireturn + | Instruction::Lreturn + | Instruction::Freturn + | Instruction::Dreturn + | Instruction::Areturn => Terminator::Return, + Instruction::Athrow => Terminator::Throw, + Instruction::Jsr(off) => Terminator::Jsr { + target: (addr as i64 + *off as i64) as u32, + return_addr: next, + }, + Instruction::JsrW(off) => Terminator::Jsr { + target: (addr as i64 + *off as i64) as u32, + return_addr: next, + }, + Instruction::Ret(_) | Instruction::RetWide(_) => Terminator::Return, + _ => Terminator::FallThrough { target: next }, + } +} diff --git a/src/decompile/cfg_types.rs b/src/decompile/cfg_types.rs new file mode 100644 index 0000000..456f088 --- /dev/null +++ b/src/decompile/cfg_types.rs @@ -0,0 +1,169 @@ +use std::collections::BTreeMap; + +use crate::code_attribute::Instruction; + +// Re-export CompareOp from expr to avoid duplication +pub use super::expr::CompareOp; + +/// Block ID is the bytecode offset of the first instruction in the block. +pub type BlockId = u32; + +/// An instruction paired with its bytecode address. +#[derive(Clone, Debug)] +pub struct AddressedInstruction { + pub address: u32, + pub instruction: Instruction, +} + +/// How a basic block ends. +#[derive(Clone, Debug)] +pub enum Terminator { + FallThrough { + target: BlockId, + }, + Goto { + target: BlockId, + }, + ConditionalBranch { + condition: BranchCondition, + if_true: BlockId, + if_false: BlockId, + }, + TableSwitch { + default: BlockId, + low: i32, + high: i32, + targets: Vec<BlockId>, + }, + LookupSwitch { + default: BlockId, + pairs: Vec<(i32, BlockId)>, + }, + Return, + Throw, + Jsr { + target: BlockId, + return_addr: BlockId, + }, +} + +/// The condition for a conditional branch. 
+#[derive(Clone, Debug)] +pub enum BranchCondition { + IntZero(CompareOp), + IntCompare(CompareOp), + RefCompare(CompareOp), + RefNull(bool), +} + +/// A basic block in the CFG. +#[derive(Clone, Debug)] +pub struct BasicBlock { + pub id: BlockId, + pub instructions: Vec<AddressedInstruction>, + pub terminator: Terminator, +} + +/// An exception handler edge. +#[derive(Clone, Debug)] +pub struct ExceptionEdge { + pub start_pc: u16, + pub end_pc: u16, + pub handler_block: BlockId, + pub catch_type: u16, +} + +/// The control flow graph for a single method. +#[derive(Clone, Debug)] +pub struct ControlFlowGraph { + pub blocks: BTreeMap<BlockId, BasicBlock>, + pub entry: BlockId, + pub exception_edges: Vec<ExceptionEdge>, +} + +impl ControlFlowGraph { + /// Get all successor block IDs for a given block. + pub fn successors(&self, block_id: BlockId) -> Vec<BlockId> { + match &self.blocks[&block_id].terminator { + Terminator::FallThrough { target } => vec![*target], + Terminator::Goto { target } => vec![*target], + Terminator::ConditionalBranch { + if_true, if_false, .. + } => vec![*if_true, *if_false], + Terminator::TableSwitch { + default, targets, .. + } => { + let mut succs: Vec<BlockId> = targets.clone(); + succs.push(*default); + succs.sort(); + succs.dedup(); + succs + } + Terminator::LookupSwitch { default, pairs, .. } => { + let mut succs: Vec<BlockId> = pairs.iter().map(|(_, t)| *t).collect(); + succs.push(*default); + succs.sort(); + succs.dedup(); + succs + } + Terminator::Return | Terminator::Throw => vec![], + Terminator::Jsr { + target, + return_addr, + } => vec![*target, *return_addr], + } + } + + /// Get all predecessor block IDs for a given block. + pub fn predecessors(&self, target: BlockId) -> Vec<BlockId> { + self.blocks + .keys() + .filter(|&&b| self.successors(b).contains(&target)) + .copied() + .collect() + } + + /// Returns block IDs in reverse postorder. 
+ pub fn reverse_postorder(&self) -> Vec<BlockId> { + let mut visited = std::collections::HashSet::new(); + let mut postorder = Vec::new(); + self.dfs_postorder(self.entry, &mut visited, &mut postorder); + postorder.reverse(); + postorder + } + + fn dfs_postorder( + &self, + block: BlockId, + visited: &mut std::collections::HashSet<BlockId>, + postorder: &mut Vec<BlockId>, + ) { + if !visited.insert(block) { + return; + } + for succ in self.successors(block) { + self.dfs_postorder(succ, visited, postorder); + } + postorder.push(block); + } + + /// Generate a DOT graph for visualization. + pub fn to_dot(&self) -> String { + let mut dot = String::from("digraph CFG {\n"); + for (id, block) in &self.blocks { + let label = format!("B{} ({} instrs)", id, block.instructions.len()); + dot.push_str(&format!(" B{} [label=\"{}\"];\n", id, label)); + for succ in self.successors(*id) { + dot.push_str(&format!(" B{} -> B{};\n", id, succ)); + } + } + for edge in &self.exception_edges { + dot.push_str(&format!( + " B{} -> B{} [style=dashed, label=\"catch\"];\n", + edge.start_pc, edge.handler_block + )); + } + dot.push_str("}\n"); + dot + } +} diff --git a/src/decompile/class_decompiler.rs b/src/decompile/class_decompiler.rs new file mode 100644 index 0000000..1e6534d --- /dev/null +++ b/src/decompile/class_decompiler.rs @@ -0,0 +1,326 @@ +use std::fmt; + +use crate::attribute_info::AttributeInfoVariant; +use crate::method_info::MethodAccessFlags; +use crate::types::ClassFile; + +use super::cfg; +use super::desugar::{self, DesugarOptions}; +use super::java_ast::*; +use super::renderer::{JavaRenderer, RenderConfig}; +use super::stack_sim; +use super::structuring; +use super::type_inference; +use super::util; + +/// Error type for decompilation failures. +#[derive(Clone, Debug)] +pub enum DecompileError { + /// The class file has no methods to decompile. + NoCode, + /// A specific method failed to decompile. + MethodError { + method_name: String, + message: String, + }, + /// General error. 
+ General(String), +} + +impl fmt::Display for DecompileError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DecompileError::NoCode => write!(f, "no code to decompile"), + DecompileError::MethodError { + method_name, + message, + } => { + write!( + f, + "failed to decompile method '{}': {}", + method_name, message + ) + } + DecompileError::General(msg) => write!(f, "{}", msg), + } + } +} + +impl std::error::Error for DecompileError {} + +/// Options controlling the decompilation process. +#[derive(Clone, Debug)] +pub struct DecompileOptions { + pub render_config: RenderConfig, + pub inline_inner_classes: bool, + pub include_synthetic: bool, + pub recover_lambdas: bool, + pub desugar_enum_switch: bool, + pub desugar_string_switch: bool, + pub desugar_try_resources: bool, + pub desugar_foreach: bool, + pub desugar_assert: bool, + pub desugar_autobox: bool, +} + +impl Default for DecompileOptions { + fn default() -> Self { + Self { + render_config: RenderConfig::default(), + inline_inner_classes: true, + include_synthetic: false, + recover_lambdas: true, + desugar_enum_switch: true, + desugar_string_switch: true, + desugar_try_resources: true, + desugar_foreach: true, + desugar_assert: true, + desugar_autobox: true, + } + } +} + +/// The main decompiler entry point. +pub struct Decompiler { + options: DecompileOptions, +} + +impl Decompiler { + pub fn new(options: DecompileOptions) -> Self { + Self { options } + } + + /// Decompile a single ClassFile to Java source. + pub fn decompile(&self, class: &ClassFile) -> Result<String, DecompileError> { + let java_class = self.build_ast(class)?; + let mut config = self.options.render_config.clone(); + config.include_synthetic = self.options.include_synthetic; + let renderer = JavaRenderer::new(config); + Ok(renderer.render_class(&java_class)) + } + + /// Decompile a single method by name. 
+ pub fn decompile_method( + &self, + class: &ClassFile, + method_name: &str, + ) -> Result<String, DecompileError> { + let java_class = self.build_ast(class)?; + let method = java_class + .methods + .iter() + .find(|m| m.name == method_name) + .ok_or_else(|| { + DecompileError::General(format!("method '{}' not found", method_name)) + })?; + + let mut config = self.options.render_config.clone(); + config.include_synthetic = self.options.include_synthetic; + let renderer = JavaRenderer::new(config); + // Render just the method in a minimal class wrapper + let wrapper = JavaClass { + kind: java_class.kind.clone(), + visibility: java_class.visibility.clone(), + is_final: false, + is_abstract: false, + is_sealed: false, + is_static: false, + annotations: Vec::new(), + type_parameters: Vec::new(), + package: java_class.package.clone(), + name: java_class.name.clone(), + super_class: None, + interfaces: Vec::new(), + permitted_subclasses: Vec::new(), + record_components: Vec::new(), + fields: Vec::new(), + methods: vec![method.clone()], + inner_classes: Vec::new(), + source_file: None, + }; + Ok(renderer.render_class(&wrapper)) + } + + /// Decompile with inner classes provided. 
+ pub fn decompile_with_inner_classes( + &self, + outer: &ClassFile, + inner: &[&ClassFile], + ) -> Result<String, DecompileError> { + let mut java_class = self.build_ast(outer)?; + + // Build and attach inner classes + for inner_class in inner { + match self.build_ast(inner_class) { + Ok(mut inner_ast) => { + inner_ast.is_static = is_static_inner(inner_class); + java_class.inner_classes.push(inner_ast); + } + Err(e) => { + // Per-class error recovery: add a stub with a comment + let name = + util::get_class_name(&inner_class.const_pool, inner_class.this_class) + .unwrap_or("Unknown"); + let simple = name.rsplit('/').next().unwrap_or(name); + let simple = simple.rsplit('$').next().unwrap_or(simple); + java_class.inner_classes.push(JavaClass { + kind: ClassKind::Class, + visibility: Visibility::PackagePrivate, + is_final: false, + is_abstract: false, + is_sealed: false, + is_static: false, + annotations: Vec::new(), + type_parameters: Vec::new(), + package: None, + name: simple.to_string(), + super_class: None, + interfaces: Vec::new(), + permitted_subclasses: Vec::new(), + record_components: Vec::new(), + fields: Vec::new(), + methods: vec![JavaMethod { + visibility: Visibility::PackagePrivate, + is_static: false, + is_final: false, + is_abstract: false, + is_synchronized: false, + is_native: false, + is_default: false, + is_synthetic: false, + is_bridge: false, + annotations: Vec::new(), + type_parameters: Vec::new(), + return_type: JavaType::Void, + name: "/* error */".into(), + parameters: Vec::new(), + throws: Vec::new(), + body: None, + error: Some(format!("Decompilation failed: {}", e)), + }], + inner_classes: Vec::new(), + source_file: None, + }); + } + } + } + + let mut config = self.options.render_config.clone(); + config.include_synthetic = self.options.include_synthetic; + let renderer = JavaRenderer::new(config); + Ok(renderer.render_class(&java_class)) + } + + /// Build the Java AST from a ClassFile, decompiling all method bodies. 
+ fn build_ast(&self, class: &ClassFile) -> Result<JavaClass, DecompileError> { + let mut java_class = type_inference::build_java_class(class); + + // Decompile each method's body + for (i, method) in class.methods.iter().enumerate() { + if i >= java_class.methods.len() { + break; + } + + let method_name = util::get_utf8(&class.const_pool, method.name_index) + .unwrap_or("unknown") + .to_string(); + + // Skip abstract and native methods + if method.access_flags.contains(MethodAccessFlags::ABSTRACT) + || method.access_flags.contains(MethodAccessFlags::NATIVE) + { + continue; + } + + let code_attr = match method.code() { + Some(code) => code, + None => continue, + }; + + // Per-method error recovery: if decompilation fails, store the error + match self.decompile_method_body( + code_attr, + &class.const_pool, + method.access_flags.contains(MethodAccessFlags::STATIC), + ) { + Ok(body) => { + java_class.methods[i].body = Some(body); + } + Err(e) => { + // Build bytecode fallback comment + let mut error_msg = format!( + "Decompilation failed for method '{}': {}\nBytecode:", + method_name, e + ); + let addressed = super::util::compute_addresses(&code_attr.code); + for (addr, instr) in addressed.iter().take(20) { + error_msg.push_str(&format!("\n {:04}: {:?}", addr, instr)); + } + if addressed.len() > 20 { + error_msg.push_str(&format!( + "\n ... 
({} more instructions)", + addressed.len() - 20 + )); + } + java_class.methods[i].error = Some(error_msg); + } + } + } + + Ok(java_class) + } + + fn decompile_method_body( + &self, + code_attr: &crate::attribute_info::CodeAttribute, + const_pool: &[crate::constant_info::ConstantInfo], + is_static: bool, + ) -> Result<super::structured_types::StructuredBody, DecompileError> { + // Phase 1: Build CFG + let cfg = cfg::build_cfg(code_attr); + + if cfg.blocks.is_empty() { + return Ok(super::structured_types::StructuredBody::new(vec![])); + } + + // Phase 2: Stack simulation + let simulated = stack_sim::simulate_all_blocks(&cfg, const_pool, code_attr, is_static); + + // Phase 3: Control flow structuring + let mut body = structuring::structure_method(&cfg, &simulated, const_pool); + + // Phase 3b: Desugaring + let desugar_options = DesugarOptions { + foreach: self.options.desugar_foreach, + try_resources: self.options.desugar_try_resources, + enum_switch: self.options.desugar_enum_switch, + string_switch: self.options.desugar_string_switch, + assert: self.options.desugar_assert, + autobox: self.options.desugar_autobox, + synthetic_accessors: true, + }; + desugar::desugar(&mut body, &desugar_options); + + Ok(body) + } +} + +fn is_static_inner(class: &ClassFile) -> bool { + // Check InnerClasses attribute for the static flag + for attr in &class.attributes { + if let Some(AttributeInfoVariant::InnerClasses(ic)) = &attr.info_parsed { + for entry in &ic.classes { + if entry.inner_class_info_index == class.this_class { + return (entry.inner_class_access_flags & 0x0008) != 0; // ACC_STATIC + } + } + } + } + false +} + +/// Convenience function: decompile a ClassFile with default options. 
+pub fn decompile(class: &ClassFile) -> Result<String, DecompileError> { + let decompiler = Decompiler::new(DecompileOptions::default()); + decompiler.decompile(class) +} diff --git a/src/decompile/descriptor.rs b/src/decompile/descriptor.rs new file mode 100644 index 0000000..34b1457 --- /dev/null +++ b/src/decompile/descriptor.rs @@ -0,0 +1,220 @@ +/// JVM type descriptor and method descriptor parser. + +/// Represents a JVM type from a descriptor string. +#[derive(Clone, Debug, PartialEq)] +pub enum JvmType { + Int, + Long, + Float, + Double, + Byte, + Char, + Short, + Boolean, + Void, + Reference(String), + Array(Box<JvmType>), + Null, + Unknown, +} + +impl JvmType { + /// Returns true if this type occupies two slots on the JVM stack. + pub fn is_wide(&self) -> bool { + matches!(self, JvmType::Long | JvmType::Double) + } + + /// Returns the JVM descriptor string for this type. + pub fn to_descriptor(&self) -> String { + match self { + JvmType::Int => "I".into(), + JvmType::Long => "J".into(), + JvmType::Float => "F".into(), + JvmType::Double => "D".into(), + JvmType::Byte => "B".into(), + JvmType::Char => "C".into(), + JvmType::Short => "S".into(), + JvmType::Boolean => "Z".into(), + JvmType::Void => "V".into(), + JvmType::Reference(name) => format!("L{};", name), + JvmType::Array(inner) => format!("[{}", inner.to_descriptor()), + JvmType::Null | JvmType::Unknown => "Ljava/lang/Object;".into(), + } + } + + /// Returns the simple (unqualified) name for display. 
+ pub fn simple_name(&self) -> String { + match self { + JvmType::Int => "int".into(), + JvmType::Long => "long".into(), + JvmType::Float => "float".into(), + JvmType::Double => "double".into(), + JvmType::Byte => "byte".into(), + JvmType::Char => "char".into(), + JvmType::Short => "short".into(), + JvmType::Boolean => "boolean".into(), + JvmType::Void => "void".into(), + JvmType::Reference(name) => internal_to_source_name(name), + JvmType::Array(inner) => format!("{}[]", inner.simple_name()), + JvmType::Null => "null".into(), + JvmType::Unknown => "/* unknown */".into(), + } + } +} + +/// Parse a single type descriptor starting at position `pos` in `desc`. +/// Returns (JvmType, next_position). +pub fn parse_type_at(desc: &str, pos: usize) -> Option<(JvmType, usize)> { + let bytes = desc.as_bytes(); + if pos >= bytes.len() { + return None; + } + match bytes[pos] { + b'B' => Some((JvmType::Byte, pos + 1)), + b'C' => Some((JvmType::Char, pos + 1)), + b'D' => Some((JvmType::Double, pos + 1)), + b'F' => Some((JvmType::Float, pos + 1)), + b'I' => Some((JvmType::Int, pos + 1)), + b'J' => Some((JvmType::Long, pos + 1)), + b'S' => Some((JvmType::Short, pos + 1)), + b'Z' => Some((JvmType::Boolean, pos + 1)), + b'V' => Some((JvmType::Void, pos + 1)), + b'L' => { + let semi = desc[pos + 1..].find(';')?; + let class_name = &desc[pos + 1..pos + 1 + semi]; + Some(( + JvmType::Reference(class_name.to_string()), + pos + 1 + semi + 1, + )) + } + b'[' => { + let (inner, next) = parse_type_at(desc, pos + 1)?; + Some((JvmType::Array(Box::new(inner)), next)) + } + _ => None, + } +} + +/// Parse a full type descriptor string. +pub fn parse_type_descriptor(desc: &str) -> Option<JvmType> { + let (ty, _) = parse_type_at(desc, 0)?; + Some(ty) +} + +/// Parse a method descriptor, e.g. 
"(II)V" -> ([Int, Int], Void) +pub fn parse_method_descriptor(desc: &str) -> Option<(Vec<JvmType>, JvmType)> { + if !desc.starts_with('(') { + return None; + } + let close = desc.find(')')?; + let mut params = Vec::new(); + let mut pos = 1; + while pos < close { + let (ty, next) = parse_type_at(desc, pos)?; + params.push(ty); + pos = next; + } + let (ret, _) = parse_type_at(desc, close + 1)?; + Some((params, ret)) +} + +/// Convert internal class name to source name. +pub fn internal_to_source_name(name: &str) -> String { + name.replace('/', ".") +} + +/// Get just the simple class name from an internal name. +pub fn simple_class_name(name: &str) -> &str { + match name.rfind('/') { + Some(pos) => &name[pos + 1..], + None => name, + } +} + +/// Get the package from an internal name. +pub fn package_name(name: &str) -> Option<&str> { + match name.rfind('/') { + Some(pos) => Some(&name[..pos]), + None => None, + } +} + +/// Convert a newarray type code to JvmType. +pub fn newarray_type(atype: u8) -> JvmType { + match atype { + 4 => JvmType::Boolean, + 5 => JvmType::Char, + 6 => JvmType::Float, + 7 => JvmType::Double, + 8 => JvmType::Byte, + 9 => JvmType::Short, + 10 => JvmType::Int, + 11 => JvmType::Long, + _ => JvmType::Unknown, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_primitives() { + assert_eq!(parse_type_descriptor("I"), Some(JvmType::Int)); + assert_eq!(parse_type_descriptor("J"), Some(JvmType::Long)); + assert_eq!(parse_type_descriptor("D"), Some(JvmType::Double)); + assert_eq!(parse_type_descriptor("V"), Some(JvmType::Void)); + assert_eq!(parse_type_descriptor("Z"), Some(JvmType::Boolean)); + } + + #[test] + fn test_parse_reference() { + assert_eq!( + parse_type_descriptor("Ljava/lang/String;"), + Some(JvmType::Reference("java/lang/String".into())) + ); + } + + #[test] + fn test_parse_array() { + assert_eq!( + parse_type_descriptor("[I"), + Some(JvmType::Array(Box::new(JvmType::Int))) + ); + assert_eq!( + 
parse_type_descriptor("[[Ljava/lang/Object;"), + Some(JvmType::Array(Box::new(JvmType::Array(Box::new( + JvmType::Reference("java/lang/Object".into()) + ))))) + ); + } + + #[test] + fn test_parse_method_descriptor() { + let (params, ret) = parse_method_descriptor("(II)V").unwrap(); + assert_eq!(params, vec![JvmType::Int, JvmType::Int]); + assert_eq!(ret, JvmType::Void); + + let (params, ret) = parse_method_descriptor("(Ljava/lang/String;I)[B").unwrap(); + assert_eq!( + params, + vec![JvmType::Reference("java/lang/String".into()), JvmType::Int] + ); + assert_eq!(ret, JvmType::Array(Box::new(JvmType::Byte))); + + let (params, ret) = parse_method_descriptor("()V").unwrap(); + assert_eq!(params, vec![]); + assert_eq!(ret, JvmType::Void); + } + + #[test] + fn test_internal_to_source() { + assert_eq!( + internal_to_source_name("java/lang/String"), + "java.lang.String" + ); + assert_eq!(simple_class_name("java/lang/String"), "String"); + assert_eq!(package_name("java/lang/String"), Some("java/lang")); + assert_eq!(package_name("NoPackage"), None); + } +} diff --git a/src/decompile/desugar.rs b/src/decompile/desugar.rs new file mode 100644 index 0000000..d78ca35 --- /dev/null +++ b/src/decompile/desugar.rs @@ -0,0 +1,357 @@ +use super::expr::*; +use super::structured_types::*; + +/// Options controlling which desugaring passes to apply. +#[derive(Clone, Debug)] +pub struct DesugarOptions { + pub foreach: bool, + pub try_resources: bool, + pub enum_switch: bool, + pub string_switch: bool, + pub assert: bool, + pub autobox: bool, + pub synthetic_accessors: bool, +} + +impl Default for DesugarOptions { + fn default() -> Self { + Self { + foreach: true, + try_resources: true, + enum_switch: true, + string_switch: true, + assert: true, + autobox: true, + synthetic_accessors: true, + } + } +} + +/// Run all enabled desugaring passes on a structured body. 
+pub fn desugar(body: &mut StructuredBody, options: &DesugarOptions) { + for stmt in &mut body.statements { + desugar_stmt(stmt, options); + } +} + +fn desugar_stmt(stmt: &mut StructuredStmt, options: &DesugarOptions) { + match stmt { + StructuredStmt::Block(stmts) => { + for s in stmts.iter_mut() { + desugar_stmt(s, options); + } + if options.foreach { + desugar_foreach_in_block(stmts); + } + if options.assert { + desugar_assert_in_block(stmts); + } + } + StructuredStmt::If { + then_body, + else_body, + condition, + .. + } => { + desugar_stmt(then_body, options); + if let Some(eb) = else_body { + desugar_stmt(eb, options); + } + if options.autobox { + desugar_autobox_expr(condition); + } + } + StructuredStmt::While { + body, condition, .. + } => { + desugar_stmt(body, options); + if options.autobox { + desugar_autobox_expr(condition); + } + } + StructuredStmt::DoWhile { + body, condition, .. + } => { + desugar_stmt(body, options); + if options.autobox { + desugar_autobox_expr(condition); + } + } + StructuredStmt::For { + init, + body, + update, + condition, + .. + } => { + if let Some(i) = init { + desugar_stmt(i, options); + } + desugar_stmt(body, options); + if let Some(u) = update { + desugar_stmt(u, options); + } + if options.autobox { + desugar_autobox_expr(condition); + } + } + StructuredStmt::ForEach { body, .. } => { + desugar_stmt(body, options); + } + StructuredStmt::Switch { + cases, + default, + expr, + .. + } => { + for case in cases.iter_mut() { + desugar_stmt(&mut case.body, options); + } + if let Some(d) = default { + desugar_stmt(d, options); + } + if options.autobox { + desugar_autobox_expr(expr); + } + } + StructuredStmt::TryCatch { + try_body, + catches, + finally_body, + .. + } => { + desugar_stmt(try_body, options); + for c in catches.iter_mut() { + desugar_stmt(&mut c.body, options); + } + if let Some(f) = finally_body { + desugar_stmt(f, options); + } + } + StructuredStmt::TryWithResources { body, catches, .. 
} => { + desugar_stmt(body, options); + for c in catches.iter_mut() { + desugar_stmt(&mut c.body, options); + } + } + StructuredStmt::Synchronized { body, .. } => { + desugar_stmt(body, options); + } + StructuredStmt::Labeled { body, .. } => { + desugar_stmt(body, options); + } + StructuredStmt::Simple(s) => { + if options.autobox { + desugar_autobox_stmt(s); + } + } + _ => {} + } +} + +/// Detect Iterator-based for-each pattern in a block: +/// ```java +/// Iterator iter = coll.iterator(); +/// while (iter.hasNext()) { T x = (T) iter.next(); ... } +/// ``` +/// Rewrites to ForEach. +fn desugar_foreach_in_block(stmts: &mut Vec<StructuredStmt>) { + let mut i = 0; + while i + 1 < stmts.len() { + let is_foreach = { + if let ( + StructuredStmt::Simple(Stmt::LocalStore { + var: iter_var, + value: iter_init, + }), + StructuredStmt::While { condition, body }, + ) = (&stmts[i], &stmts[i + 1]) + { + is_iterator_call(iter_init) + && is_has_next_call(condition, iter_var) + && find_next_call_in_body(body, iter_var).is_some() + } else { + false + } + }; + + if is_foreach + && let StructuredStmt::Simple(Stmt::LocalStore { + value: iter_init, .. + }) = &stmts[i] + { + let iterable = extract_iterator_receiver(iter_init) + .unwrap_or_else(|| Expr::Unresolved("/* iterable */".into())); + + if let StructuredStmt::While { body, .. } = &stmts[i + 1] + && let Some((loop_var, remaining_body)) = find_next_call_in_body( + body, + &LocalVar { + index: 0, + name: None, + ty: super::descriptor::JvmType::Unknown, + }, + ) + { + let foreach = StructuredStmt::ForEach { + var: loop_var, + iterable, + body: Box::new(remaining_body), + }; + stmts.splice(i..=i + 1, std::iter::once(foreach)); + continue; + } + } + i += 1; + } +} + +fn is_iterator_call(expr: &Expr) -> bool { + matches!(expr, Expr::MethodCall { method_name, .. } if method_name == "iterator") +} + +fn is_has_next_call(expr: &Expr, _iter_var: &LocalVar) -> bool { + matches!(expr, Expr::MethodCall { method_name, .. 
} if method_name == "hasNext") +} + +fn extract_iterator_receiver(expr: &Expr) -> Option<Expr> { + if let Expr::MethodCall { + object: Some(obj), + method_name, + .. + } = expr + && method_name == "iterator" + { + return Some(*obj.clone()); + } + None +} + +fn find_next_call_in_body( + _body: &StructuredStmt, + _iter_var: &LocalVar, +) -> Option<(LocalVar, StructuredStmt)> { + // TODO: Look for `T x = (T) iter.next()` as the first statement of the body + None +} + +/// Detect `if (!$assertionsDisabled && !cond) throw new AssertionError(msg)` pattern. +fn desugar_assert_in_block(stmts: &mut Vec<StructuredStmt>) { + let mut i = 0; + while i < stmts.len() { + let replacement = match &stmts[i] { + StructuredStmt::If { + condition, + then_body, + else_body: None, + } => { + if let Some((assert_cond, assert_msg)) = match_assert_pattern(condition, then_body) + { + Some(StructuredStmt::Assert { + condition: assert_cond, + message: assert_msg, + }) + } else { + None + } + } + _ => None, + }; + + if let Some(repl) = replacement { + stmts[i] = repl; + } + i += 1; + } +} + +fn match_assert_pattern( + _condition: &Expr, + _then_body: &StructuredStmt, +) -> Option<(Expr, Option<Expr>)> { + // TODO: Match the pattern: + // condition = !$assertionsDisabled (a FieldGet for a static boolean field named "$assertionsDisabled") + // combined with the actual assertion condition + // then_body = throw new AssertionError(...) + None +} + +/// Desugar autoboxing/unboxing in expressions. +fn desugar_autobox_expr(expr: &mut Expr) { + *expr = desugar_autobox_inner(expr.clone()); +} + +fn desugar_autobox_inner(expr: Expr) -> Expr { + match expr { + // Integer.valueOf(n) -> n + Expr::MethodCall { + kind: InvokeKind::Static, + ref class_name, + ref method_name, + ref args, + .. + } if method_name == "valueOf" && is_wrapper_class(class_name) && args.len() == 1 => { + desugar_autobox_inner(args[0].clone()) + } + // n.intValue() / n.longValue() / etc -> n + Expr::MethodCall { + ref object, + ref method_name, + ref args, + .. 
+ } if is_unbox_method(method_name) && args.is_empty() && object.is_some() => { + desugar_autobox_inner(*object.as_ref().unwrap().clone()) + } + // Recurse into sub-expressions + Expr::BinaryOp { op, left, right } => Expr::BinaryOp { + op, + left: Box::new(desugar_autobox_inner(*left)), + right: Box::new(desugar_autobox_inner(*right)), + }, + Expr::UnaryOp { op, operand } => Expr::UnaryOp { + op, + operand: Box::new(desugar_autobox_inner(*operand)), + }, + other => other, + } +} + +fn desugar_autobox_stmt(stmt: &mut Stmt) { + match stmt { + Stmt::LocalStore { value, .. } => *value = desugar_autobox_inner(value.clone()), + Stmt::FieldStore { value, .. } => *value = desugar_autobox_inner(value.clone()), + Stmt::ArrayStore { value, .. } => *value = desugar_autobox_inner(value.clone()), + Stmt::ExprStmt(e) => *e = desugar_autobox_inner(e.clone()), + Stmt::Return(Some(e)) => *e = desugar_autobox_inner(e.clone()), + Stmt::Throw(e) => *e = desugar_autobox_inner(e.clone()), + _ => {} + } +} + +fn is_wrapper_class(name: &str) -> bool { + matches!( + name, + "java/lang/Integer" + | "java/lang/Long" + | "java/lang/Float" + | "java/lang/Double" + | "java/lang/Byte" + | "java/lang/Short" + | "java/lang/Character" + | "java/lang/Boolean" + ) +} + +fn is_unbox_method(name: &str) -> bool { + matches!( + name, + "intValue" + | "longValue" + | "floatValue" + | "doubleValue" + | "byteValue" + | "shortValue" + | "charValue" + | "booleanValue" + ) +} diff --git a/src/decompile/expr.rs b/src/decompile/expr.rs new file mode 100644 index 0000000..e05add3 --- /dev/null +++ b/src/decompile/expr.rs @@ -0,0 +1,248 @@ +use super::cfg_types::BlockId; +use super::descriptor::JvmType; + +/// Binary operators. +#[derive(Clone, Debug, PartialEq)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Rem, + Shl, + Shr, + Ushr, + And, + Or, + Xor, +} + +/// Unary operators. 
+#[derive(Clone, Debug, PartialEq)] +pub enum UnaryOp { + Neg, + Not, // bitwise not (for boolean negation in conditions) +} + +/// Comparison operators. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CompareOp { + Eq, + Ne, + Lt, + Ge, + Gt, + Le, +} + +impl CompareOp { + /// Returns the negated comparison. + pub fn negate(self) -> Self { + match self { + CompareOp::Eq => CompareOp::Ne, + CompareOp::Ne => CompareOp::Eq, + CompareOp::Lt => CompareOp::Ge, + CompareOp::Ge => CompareOp::Lt, + CompareOp::Gt => CompareOp::Le, + CompareOp::Le => CompareOp::Gt, + } + } + + /// Java source token for this operator. + pub fn as_str(&self) -> &'static str { + match self { + CompareOp::Eq => "==", + CompareOp::Ne => "!=", + CompareOp::Lt => "<", + CompareOp::Ge => ">=", + CompareOp::Gt => ">", + CompareOp::Le => "<=", + } + } +} + +/// Method invocation kind. +#[derive(Clone, Debug, PartialEq)] +pub enum InvokeKind { + Virtual, + Special, + Static, + Interface, +} + +/// Local variable reference. +#[derive(Clone, Debug, PartialEq)] +pub struct LocalVar { + pub index: u16, + pub name: Option<String>, + pub ty: JvmType, +} + +/// Expression tree node -- represents a value-producing computation. 
+#[derive(Clone, Debug)] +pub enum Expr { + // --- Literals --- + IntLiteral(i32), + LongLiteral(i64), + FloatLiteral(f32), + DoubleLiteral(f64), + StringLiteral(String), + ClassLiteral(String), + NullLiteral, + + // --- Variables --- + LocalLoad(LocalVar), + This, + + // --- Operations --- + BinaryOp { + op: BinOp, + left: Box<Expr>, + right: Box<Expr>, + }, + UnaryOp { + op: UnaryOp, + operand: Box<Expr>, + }, + Cast { + target_type: JvmType, + operand: Box<Expr>, + }, + Instanceof { + operand: Box<Expr>, + check_type: String, + }, + + // --- Field access --- + FieldGet { + object: Option<Box<Expr>>, + class_name: String, + field_name: String, + field_type: JvmType, + }, + + // --- Method invocation --- + MethodCall { + kind: InvokeKind, + object: Option<Box<Expr>>, + class_name: String, + method_name: String, + descriptor: String, + args: Vec<Expr>, + return_type: JvmType, + }, + + // --- Object creation --- + New { + class_name: String, + constructor_descriptor: String, + args: Vec<Expr>, + }, + NewArray { + element_type: JvmType, + length: Box<Expr>, + }, + NewMultiArray { + element_type: JvmType, + dimensions: Vec<Expr>, + }, + ArrayLength { + array: Box<Expr>, + }, + ArrayLoad { + array: Box<Expr>, + index: Box<Expr>, + element_type: JvmType, + }, + + // --- Comparison --- + Compare { + op: CompareOp, + left: Box<Expr>, + right: Box<Expr>, + }, + /// Result of lcmp/fcmpl/fcmpg/dcmpl/dcmpg: -1, 0, or 1 + CmpResult { + kind: CmpKind, + left: Box<Expr>, + right: Box<Expr>, + }, + + // --- invokedynamic (lambdas) --- + InvokeDynamic { + bootstrap_index: u16, + method_name: String, + descriptor: String, + captures: Vec<Expr>, + }, + + // --- Ternary (synthesized during structuring) --- + Ternary { + condition: Box<Expr>, + then_expr: Box<Expr>, + else_expr: Box<Expr>, + }, + + // --- Fallback --- + Unresolved(String), + + // --- Stack bookkeeping (used during simulation, cleaned up after) --- + Dup(Box<Expr>), + /// Marker for an uninitialized `new` before `<init>` is called + UninitNew { + class_name: String, + }, +} + +/// Compare instruction kinds (for lcmp, fcmpl, etc.) 
+#[derive(Clone, Debug, PartialEq)] +pub enum CmpKind { + LCmp, + FCmpL, + FCmpG, + DCmpL, + DCmpG, +} + +/// Statement -- represents a side-effecting operation. +#[derive(Clone, Debug)] +pub enum Stmt { + LocalStore { + var: LocalVar, + value: Expr, + }, + FieldStore { + object: Option<Expr>, + class_name: String, + field_name: String, + field_type: JvmType, + value: Expr, + }, + ArrayStore { + array: Expr, + index: Expr, + value: Expr, + }, + ExprStmt(Expr), + Iinc { + var: LocalVar, + amount: i32, + }, + Return(Option<Expr>), + Throw(Expr), + Monitor { + enter: bool, + object: Expr, + }, +} + +/// A simulated basic block: the result of stack-simulating one BasicBlock. +#[derive(Clone, Debug)] +pub struct SimulatedBlock { + pub id: BlockId, + pub statements: Vec<Stmt>, + pub exit_stack: Vec<Expr>, + pub terminator: super::cfg_types::Terminator, + /// Branch condition expression (populated for ConditionalBranch terminators) + pub branch_condition: Option<Expr>, +} diff --git a/src/decompile/java_ast.rs b/src/decompile/java_ast.rs new file mode 100644 index 0000000..8494e89 --- /dev/null +++ b/src/decompile/java_ast.rs @@ -0,0 +1,223 @@ +use super::expr::Expr; +use super::structured_types::StructuredBody; + +/// Primitive types in Java. +#[derive(Clone, Debug, PartialEq)] +pub enum PrimitiveType { + Boolean, + Byte, + Char, + Short, + Int, + Long, + Float, + Double, +} + +/// A Java type as it appears in source code (with generics). +#[derive(Clone, Debug, PartialEq)] +pub enum JavaType { + Primitive(PrimitiveType), + ClassType { + package: Option<String>, + name: String, + type_args: Vec<JavaType>, + }, + ArrayType(Box<JavaType>), + WildcardType { + bound: Option<Box<JavaType>>, + is_upper: bool, + }, + TypeVariable(String), + Void, +} + +impl JavaType { + /// Get the simple display name for this type. 
+ pub fn display_name(&self) -> String { + match self { + JavaType::Primitive(p) => match p { + PrimitiveType::Boolean => "boolean".into(), + PrimitiveType::Byte => "byte".into(), + PrimitiveType::Char => "char".into(), + PrimitiveType::Short => "short".into(), + PrimitiveType::Int => "int".into(), + PrimitiveType::Long => "long".into(), + PrimitiveType::Float => "float".into(), + PrimitiveType::Double => "double".into(), + }, + JavaType::ClassType { + name, type_args, .. + } => { + if type_args.is_empty() { + name.clone() + } else { + let args: Vec<String> = type_args.iter().map(|a| a.display_name()).collect(); + format!("{}<{}>", name, args.join(", ")) + } + } + JavaType::ArrayType(inner) => format!("{}[]", inner.display_name()), + JavaType::WildcardType { bound, is_upper } => match bound { + Some(b) => { + if *is_upper { + format!("? extends {}", b.display_name()) + } else { + format!("? super {}", b.display_name()) + } + } + None => "?".into(), + }, + JavaType::TypeVariable(name) => name.clone(), + JavaType::Void => "void".into(), + } + } + + /// Check if this is the java.lang.Object type. + pub fn is_object(&self) -> bool { + matches!(self, JavaType::ClassType { name, .. } if name == "Object") + } +} + +/// Visibility level. +#[derive(Clone, Debug, PartialEq)] +pub enum Visibility { + Public, + Protected, + PackagePrivate, + Private, +} + +/// What kind of class-like entity this is. +#[derive(Clone, Debug, PartialEq)] +pub enum ClassKind { + Class, + Interface, + Enum, + Annotation, + Record, +} + +/// A generic type parameter declaration. +#[derive(Clone, Debug)] +pub struct TypeParameter { + pub name: String, + pub bounds: Vec<JavaType>, +} + +/// A Java annotation usage. +#[derive(Clone, Debug)] +pub struct JavaAnnotation { + pub type_name: String, + pub arguments: Vec<AnnotationArgument>, +} + +/// An annotation argument. 
+#[derive(Clone, Debug)] +pub enum AnnotationArgument { + /// Named argument: `@Foo(name = value)` + Named { + name: String, + value: AnnotationValue, + }, + /// Unnamed (single-element): `@Foo(value)` + Unnamed(AnnotationValue), +} + +/// An annotation element value. +#[derive(Clone, Debug)] +pub enum AnnotationValue { + IntLiteral(i32), + LongLiteral(i64), + FloatLiteral(f32), + DoubleLiteral(f64), + StringLiteral(String), + BooleanLiteral(bool), + CharLiteral(char), + ClassLiteral(String), + EnumConstant { + type_name: String, + const_name: String, + }, + AnnotationLiteral(JavaAnnotation), + ArrayLiteral(Vec<AnnotationValue>), +} + +/// A Java class / interface / enum / record / annotation. +#[derive(Clone, Debug)] +pub struct JavaClass { + pub kind: ClassKind, + pub visibility: Visibility, + pub is_final: bool, + pub is_abstract: bool, + pub is_sealed: bool, + pub is_static: bool, + pub annotations: Vec<JavaAnnotation>, + pub type_parameters: Vec<TypeParameter>, + pub package: Option<String>, + pub name: String, + pub super_class: Option<JavaType>, + pub interfaces: Vec<JavaType>, + pub permitted_subclasses: Vec<JavaType>, + pub record_components: Vec<RecordComponent>, + pub fields: Vec<JavaField>, + pub methods: Vec<JavaMethod>, + pub inner_classes: Vec<JavaClass>, + pub source_file: Option<String>, +} + +/// A record component declaration. +#[derive(Clone, Debug)] +pub struct RecordComponent { + pub annotations: Vec<JavaAnnotation>, + pub component_type: JavaType, + pub name: String, +} + +/// A method parameter declaration. +#[derive(Clone, Debug)] +pub struct JavaParameter { + pub annotations: Vec<JavaAnnotation>, + pub param_type: JavaType, + pub name: String, + pub is_final: bool, + pub is_varargs: bool, +} + +/// A Java method declaration. 
+#[derive(Clone, Debug)] +pub struct JavaMethod { + pub visibility: Visibility, + pub is_static: bool, + pub is_final: bool, + pub is_abstract: bool, + pub is_synchronized: bool, + pub is_native: bool, + pub is_default: bool, + pub is_synthetic: bool, + pub is_bridge: bool, + pub annotations: Vec<JavaAnnotation>, + pub type_parameters: Vec<TypeParameter>, + pub return_type: JavaType, + pub name: String, + pub parameters: Vec<JavaParameter>, + pub throws: Vec<JavaType>, + pub body: Option<StructuredBody>, + /// If decompilation failed, this holds the error message and bytecode fallback. + pub error: Option<String>, +} + +/// A Java field declaration. +#[derive(Clone, Debug)] +pub struct JavaField { + pub visibility: Visibility, + pub is_static: bool, + pub is_final: bool, + pub is_volatile: bool, + pub is_transient: bool, + pub is_synthetic: bool, + pub is_enum_constant: bool, + pub annotations: Vec<JavaAnnotation>, + pub field_type: JavaType, + pub name: String, + pub initializer: Option<Expr>, +} diff --git a/src/decompile/mod.rs b/src/decompile/mod.rs new file mode 100644 index 0000000..a507292 --- /dev/null +++ b/src/decompile/mod.rs @@ -0,0 +1,20 @@ +pub mod cfg; +pub mod cfg_types; +pub mod class_decompiler; +pub mod descriptor; +pub mod desugar; +pub mod expr; +pub mod java_ast; +pub mod renderer; +pub mod stack_sim; +pub mod structured_types; +pub mod structuring; +pub mod type_inference; +pub mod util; + +pub use cfg_types::ControlFlowGraph; +pub use class_decompiler::{DecompileError, DecompileOptions, Decompiler, decompile}; +pub use expr::{BinOp, CompareOp, Expr, InvokeKind, Stmt, UnaryOp}; +pub use java_ast::*; +pub use renderer::RenderConfig; +pub use structured_types::StructuredStmt; diff --git a/src/decompile/renderer.rs b/src/decompile/renderer.rs new file mode 100644 index 0000000..9101c28 --- /dev/null +++ b/src/decompile/renderer.rs @@ -0,0 +1,1193 @@ +use std::collections::BTreeSet; +use std::fmt::Write; + +use super::descriptor; +use super::expr::*; +use super::java_ast::*; +use super::structured_types::*; + +/// Configuration for rendering Java 
source code. +#[derive(Clone, Debug)] +pub struct RenderConfig { + pub indent: String, + pub max_line_width: usize, + pub use_var: bool, + pub include_synthetic: bool, +} + +impl Default for RenderConfig { + fn default() -> Self { + Self { + indent: " ".into(), + max_line_width: 120, + use_var: false, + include_synthetic: false, + } + } +} + +/// Java source code renderer. +pub struct JavaRenderer { + config: RenderConfig, + imports: BTreeSet<String>, + output: String, + indent_level: usize, +} + +impl JavaRenderer { + pub fn new(config: RenderConfig) -> Self { + Self { + config, + imports: BTreeSet::new(), + output: String::new(), + indent_level: 0, + } + } + + pub fn render_class(mut self, class: &JavaClass) -> String { + // Collect imports + self.collect_imports(class); + + // Package declaration + if let Some(ref pkg) = class.package { + self.writeln(&format!("package {};", pkg)); + self.newline(); + } + + // Import declarations + if !self.imports.is_empty() { + let imports: Vec<String> = self.imports.iter().cloned().collect(); + for import in &imports { + self.writeln(&format!("import {};", import)); + } + self.newline(); + } + + // Class declaration + self.render_class_decl(class); + + self.output + } + + fn render_class_decl(&mut self, class: &JavaClass) { + // Annotations + for ann in &class.annotations { + self.write_indent(); + self.render_annotation(ann); + self.raw_newline(); + } + + // Modifiers + kind + name + self.write_indent(); + let mut decl = String::new(); + self.append_visibility(&mut decl, &class.visibility); + if class.is_abstract && class.kind != ClassKind::Interface { + decl.push_str("abstract "); + } + if class.is_sealed { + decl.push_str("sealed "); + } + if class.is_final && class.kind != ClassKind::Enum && class.kind != ClassKind::Record { + decl.push_str("final "); + } + if class.is_static { + decl.push_str("static "); + } + + match class.kind { + ClassKind::Class => decl.push_str("class "), + ClassKind::Interface => decl.push_str("interface "), + 
ClassKind::Enum => decl.push_str("enum "), + ClassKind::Annotation => decl.push_str("@interface "), + ClassKind::Record => decl.push_str("record "), + } + + decl.push_str(&class.name); + + // Type parameters + if !class.type_parameters.is_empty() { + decl.push('<'); + for (i, tp) in class.type_parameters.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + decl.push_str(&tp.name); + if !tp.bounds.is_empty() { + decl.push_str(" extends "); + for (j, b) in tp.bounds.iter().enumerate() { + if j > 0 { + decl.push_str(" & "); + } + decl.push_str(&b.display_name()); + } + } + } + decl.push('>'); + } + + // Record components + if class.kind == ClassKind::Record { + decl.push('('); + for (i, comp) in class.record_components.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + decl.push_str(&comp.component_type.display_name()); + decl.push(' '); + decl.push_str(&comp.name); + } + decl.push(')'); + } + + // Extends + if let Some(ref super_class) = class.super_class + && class.kind == ClassKind::Class + { + decl.push_str(" extends "); + decl.push_str(&super_class.display_name()); + } + + // Implements / extends (for interfaces) + if !class.interfaces.is_empty() { + if class.kind == ClassKind::Interface { + decl.push_str(" extends "); + } else { + decl.push_str(" implements "); + } + for (i, iface) in class.interfaces.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + decl.push_str(&iface.display_name()); + } + } + + // Permits + if !class.permitted_subclasses.is_empty() { + decl.push_str(" permits "); + for (i, sub) in class.permitted_subclasses.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + decl.push_str(&sub.display_name()); + } + } + + decl.push_str(" {"); + self.raw(&decl); + self.raw_newline(); + self.indent_level += 1; + + // Enum constants + if class.kind == ClassKind::Enum { + self.render_enum_constants(class); + } + + // Fields + let visible_fields: Vec<&JavaField> = class + .fields + .iter() + .filter(|f| 
self.should_show_field(f, &class.kind)) + .collect(); + for field in &visible_fields { + self.render_field(field); + } + if !visible_fields.is_empty() { + self.newline(); + } + + // Methods + let visible_methods: Vec<&JavaMethod> = class + .methods + .iter() + .filter(|m| self.should_show_method(m, &class.kind)) + .collect(); + for (i, method) in visible_methods.iter().enumerate() { + self.render_method(method, &class.kind, &class.name); + if i + 1 < visible_methods.len() { + self.newline(); + } + } + + // Inner classes + for inner in &class.inner_classes { + self.newline(); + self.render_class_decl(inner); + } + + self.indent_level -= 1; + self.writeln("}"); + } + + fn render_enum_constants(&mut self, class: &JavaClass) { + let enum_fields: Vec<&JavaField> = + class.fields.iter().filter(|f| f.is_enum_constant).collect(); + + if !enum_fields.is_empty() { + for (i, field) in enum_fields.iter().enumerate() { + self.write_indent(); + self.raw(&field.name); + if i + 1 < enum_fields.len() { + self.raw(","); + } else { + self.raw(";"); + } + self.raw_newline(); + } + self.newline(); + } + } + + fn render_field(&mut self, field: &JavaField) { + for ann in &field.annotations { + self.write_indent(); + self.render_annotation(ann); + self.raw_newline(); + } + + self.write_indent(); + let mut decl = String::new(); + self.append_visibility(&mut decl, &field.visibility); + if field.is_static { + decl.push_str("static "); + } + if field.is_final { + decl.push_str("final "); + } + if field.is_volatile { + decl.push_str("volatile "); + } + if field.is_transient { + decl.push_str("transient "); + } + decl.push_str(&field.field_type.display_name()); + decl.push(' '); + decl.push_str(&field.name); + + if let Some(ref init) = field.initializer { + decl.push_str(" = "); + decl.push_str(&self.render_expr(init)); + } + + decl.push(';'); + self.raw(&decl); + self.raw_newline(); + } + + fn render_method(&mut self, method: &JavaMethod, class_kind: &ClassKind, class_name: &str) { + // 
Annotations + for ann in &method.annotations { + self.write_indent(); + self.render_annotation(ann); + self.raw_newline(); + } + + self.write_indent(); + let mut decl = String::new(); + self.append_visibility(&mut decl, &method.visibility); + if method.is_default { + decl.push_str("default "); + } + if method.is_static { + decl.push_str("static "); + } + if method.is_abstract && *class_kind != ClassKind::Interface { + decl.push_str("abstract "); + } + if method.is_final { + decl.push_str("final "); + } + if method.is_synchronized { + decl.push_str("synchronized "); + } + if method.is_native { + decl.push_str("native "); + } + + // Type parameters + if !method.type_parameters.is_empty() { + decl.push('<'); + for (i, tp) in method.type_parameters.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + decl.push_str(&tp.name); + if !tp.bounds.is_empty() { + decl.push_str(" extends "); + for (j, b) in tp.bounds.iter().enumerate() { + if j > 0 { + decl.push_str(" & "); + } + decl.push_str(&b.display_name()); + } + } + } + decl.push_str("> "); + } + + let is_constructor = method.name == "<init>"; + let is_static_init = method.name == "<clinit>"; + + if is_static_init { + decl.clear(); + self.raw("static"); + } else if is_constructor { + decl.push_str(class_name); + } else { + decl.push_str(&method.return_type.display_name()); + decl.push(' '); + decl.push_str(&method.name); + } + + if !is_static_init { + decl.push('('); + for (i, param) in method.parameters.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + for ann in &param.annotations { + self.render_annotation_to_string(ann, &mut decl); + decl.push(' '); + } + if param.is_final { + decl.push_str("final "); + } + if param.is_varargs { + // Replace last [] with ... 
+ let type_str = param.param_type.display_name(); + if let Some(stripped) = type_str.strip_suffix("[]") { + decl.push_str(stripped); + decl.push_str("..."); + } else { + decl.push_str(&type_str); + } + } else { + decl.push_str(&param.param_type.display_name()); + } + decl.push(' '); + decl.push_str(&param.name); + } + decl.push(')'); + } + + // Throws + if !method.throws.is_empty() { + decl.push_str(" throws "); + for (i, t) in method.throws.iter().enumerate() { + if i > 0 { + decl.push_str(", "); + } + decl.push_str(&t.display_name()); + } + } + + self.raw(&decl); + + // Body + if let Some(ref error) = method.error { + self.raw(" {"); + self.raw_newline(); + self.indent_level += 1; + for line in error.lines() { + self.writeln(&format!("// {}", line)); + } + self.indent_level -= 1; + self.writeln("}"); + } else if let Some(ref body) = method.body { + self.raw(" {"); + self.raw_newline(); + self.indent_level += 1; + self.render_body(body); + self.indent_level -= 1; + self.writeln("}"); + } else if method.is_abstract + || method.is_native + || (*class_kind == ClassKind::Interface && !method.is_default && !method.is_static) + { + self.raw(";"); + self.raw_newline(); + } else { + self.raw(" {"); + self.raw_newline(); + self.writeln("}"); + } + } + + fn render_body(&mut self, body: &StructuredBody) { + for stmt in &body.statements { + self.render_structured_stmt(stmt); + } + } + + fn render_structured_stmt(&mut self, stmt: &StructuredStmt) { + match stmt { + StructuredStmt::Simple(s) => self.render_simple_stmt(s), + StructuredStmt::Block(stmts) => { + for s in stmts { + self.render_structured_stmt(s); + } + } + StructuredStmt::If { + condition, + then_body, + else_body, + } => { + self.write_indent(); + self.raw(&format!("if ({}) {{", self.render_expr(condition))); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(then_body); + self.indent_level -= 1; + if let Some(eb) = else_body { + self.writeln("} else {"); + self.indent_level += 1; + 
self.render_structured_stmt(eb); + self.indent_level -= 1; + } + self.writeln("}"); + } + StructuredStmt::While { condition, body } => { + self.write_indent(); + self.raw(&format!("while ({}) {{", self.render_expr(condition))); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(body); + self.indent_level -= 1; + self.writeln("}"); + } + StructuredStmt::DoWhile { body, condition } => { + self.writeln("do {"); + self.indent_level += 1; + self.render_structured_stmt(body); + self.indent_level -= 1; + self.write_indent(); + self.raw(&format!("}} while ({});", self.render_expr(condition))); + self.raw_newline(); + } + StructuredStmt::For { + init, + condition, + update, + body, + } => { + self.write_indent(); + let init_str = init + .as_ref() + .map(|s| self.render_stmt_inline(s)) + .unwrap_or_default(); + let update_str = update + .as_ref() + .map(|s| self.render_stmt_inline(s)) + .unwrap_or_default(); + self.raw(&format!( + "for ({}; {}; {}) {{", + init_str, + self.render_expr(condition), + update_str + )); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(body); + self.indent_level -= 1; + self.writeln("}"); + } + StructuredStmt::ForEach { + var, + iterable, + body, + } => { + self.write_indent(); + let type_name = var.ty.simple_name(); + let var_name = var.name.as_deref().unwrap_or("item"); + self.raw(&format!( + "for ({} {} : {}) {{", + type_name, + var_name, + self.render_expr(iterable) + )); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(body); + self.indent_level -= 1; + self.writeln("}"); + } + StructuredStmt::Switch { + expr, + cases, + default, + } => { + self.write_indent(); + self.raw(&format!("switch ({}) {{", self.render_expr(expr))); + self.raw_newline(); + self.indent_level += 1; + for case in cases { + self.write_indent(); + let labels: Vec<String> = case + .values + .iter() + .map(|v| match v { + SwitchValue::Int(i) => format!("{}", i), + SwitchValue::String(s) => 
format!("\"{}\"", s), + SwitchValue::Enum { const_name, .. } => const_name.clone(), + }) + .collect(); + for (i, label) in labels.iter().enumerate() { + if i > 0 { + self.raw_newline(); + self.write_indent(); + } + self.raw(&format!("case {}:", label)); + } + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(&case.body); + if !case.falls_through { + self.writeln("break;"); + } + self.indent_level -= 1; + } + if let Some(def) = default { + self.writeln("default:"); + self.indent_level += 1; + self.render_structured_stmt(def); + self.writeln("break;"); + self.indent_level -= 1; + } + self.indent_level -= 1; + self.writeln("}"); + } + StructuredStmt::TryCatch { + try_body, + catches, + finally_body, + } => { + self.writeln("try {"); + self.indent_level += 1; + self.render_structured_stmt(try_body); + self.indent_level -= 1; + for catch in catches { + let exc_type = catch.exception_type.as_deref().unwrap_or("Throwable"); + let var_name = catch.var.name.as_deref().unwrap_or("e"); + self.write_indent(); + self.raw(&format!("}} catch ({} {}) {{", exc_type, var_name)); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(&catch.body); + self.indent_level -= 1; + } + if let Some(fin) = finally_body { + self.writeln("} finally {"); + self.indent_level += 1; + self.render_structured_stmt(fin); + self.indent_level -= 1; + } + self.writeln("}"); + } + StructuredStmt::TryWithResources { + resources, + body, + catches, + } => { + self.write_indent(); + self.raw("try ("); + for (i, (var, init)) in resources.iter().enumerate() { + if i > 0 { + self.raw("; "); + } + let type_name = var.ty.simple_name(); + let var_name = var.name.as_deref().unwrap_or("r"); + self.raw(&format!( + "{} {} = {}", + type_name, + var_name, + self.render_expr(init) + )); + } + self.raw(") {"); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(body); + self.indent_level -= 1; + for catch in catches { + let exc_type = 
catch.exception_type.as_deref().unwrap_or("Throwable"); + let var_name = catch.var.name.as_deref().unwrap_or("e"); + self.write_indent(); + self.raw(&format!("}} catch ({} {}) {{", exc_type, var_name)); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(&catch.body); + self.indent_level -= 1; + } + self.writeln("}"); + } + StructuredStmt::Synchronized { object, body } => { + self.write_indent(); + self.raw(&format!("synchronized ({}) {{", self.render_expr(object))); + self.raw_newline(); + self.indent_level += 1; + self.render_structured_stmt(body); + self.indent_level -= 1; + self.writeln("}"); + } + StructuredStmt::Labeled { label, body } => { + self.writeln(&format!("{}:", label)); + self.render_structured_stmt(body); + } + StructuredStmt::Break { label } => { + if let Some(l) = label { + self.writeln(&format!("break {};", l)); + } else { + self.writeln("break;"); + } + } + StructuredStmt::Continue { label } => { + if let Some(l) = label { + self.writeln(&format!("continue {};", l)); + } else { + self.writeln("continue;"); + } + } + StructuredStmt::Assert { condition, message } => { + self.write_indent(); + if let Some(msg) = message { + self.raw(&format!( + "assert {} : {};", + self.render_expr(condition), + self.render_expr(msg) + )); + } else { + self.raw(&format!("assert {};", self.render_expr(condition))); + } + self.raw_newline(); + } + StructuredStmt::UnstructuredGoto { target } => { + self.writeln(&format!("// goto B{}", target)); + } + StructuredStmt::Comment(text) => { + self.writeln(&format!("// {}", text)); + } + } + } + + fn render_simple_stmt(&mut self, stmt: &Stmt) { + match stmt { + Stmt::LocalStore { var, value } => { + let type_name = var.ty.simple_name(); + let default_name = format!("var{}", var.index); + let var_name = var.name.as_deref().unwrap_or(&default_name); + // TODO: Track which variables have been declared vs assigned + self.writeln(&format!( + "{} {} = {};", + type_name, + var_name, + 
self.render_expr(value) + )); + } + Stmt::FieldStore { + object, + class_name, + field_name, + value, + .. + } => { + let target = match object { + Some(obj) => format!("{}.{}", self.render_expr(obj), field_name), + None => format!( + "{}.{}", + descriptor::simple_class_name(class_name), + field_name + ), + }; + self.writeln(&format!("{} = {};", target, self.render_expr(value))); + } + Stmt::ArrayStore { + array, + index, + value, + } => { + self.writeln(&format!( + "{}[{}] = {};", + self.render_expr(array), + self.render_expr(index), + self.render_expr(value) + )); + } + Stmt::ExprStmt(expr) => { + self.writeln(&format!("{};", self.render_expr(expr))); + } + Stmt::Iinc { var, amount } => { + let default_name = format!("var{}", var.index); + let var_name = var.name.as_deref().unwrap_or(&default_name); + if *amount == 1 { + self.writeln(&format!("{}++;", var_name)); + } else if *amount == -1 { + self.writeln(&format!("{}--;", var_name)); + } else { + self.writeln(&format!("{} += {};", var_name, amount)); + } + } + Stmt::Return(None) => { + self.writeln("return;"); + } + Stmt::Return(Some(expr)) => { + self.writeln(&format!("return {};", self.render_expr(expr))); + } + Stmt::Throw(expr) => { + self.writeln(&format!("throw {};", self.render_expr(expr))); + } + Stmt::Monitor { enter, object } => { + if *enter { + self.writeln(&format!("// monitorenter {}", self.render_expr(object))); + } else { + self.writeln(&format!("// monitorexit {}", self.render_expr(object))); + } + } + } + } + + fn render_expr(&self, expr: &Expr) -> String { + match expr { + Expr::IntLiteral(v) => format!("{}", v), + Expr::LongLiteral(v) => format!("{}L", v), + Expr::FloatLiteral(v) => { + if v.is_nan() { + "Float.NaN".into() + } else if v.is_infinite() { + if *v > 0.0 { + "Float.POSITIVE_INFINITY".into() + } else { + "Float.NEGATIVE_INFINITY".into() + } + } else { + format!("{}f", v) + } + } + Expr::DoubleLiteral(v) => { + if v.is_nan() { + "Double.NaN".into() + } else if v.is_infinite() { + if 
*v > 0.0 { + "Double.POSITIVE_INFINITY".into() + } else { + "Double.NEGATIVE_INFINITY".into() + } + } else { + format!("{}d", v) + } + } + Expr::StringLiteral(s) => format!("\"{}\"", escape_java_string(s)), + Expr::ClassLiteral(name) => format!("{}.class", descriptor::simple_class_name(name)), + Expr::NullLiteral => "null".into(), + Expr::LocalLoad(var) => var + .name + .as_deref() + .unwrap_or(&format!("var{}", var.index)) + .to_string(), + Expr::This => "this".into(), + Expr::BinaryOp { op, left, right } => { + let op_str = match op { + BinOp::Add => "+", + BinOp::Sub => "-", + BinOp::Mul => "*", + BinOp::Div => "/", + BinOp::Rem => "%", + BinOp::Shl => "<<", + BinOp::Shr => ">>", + BinOp::Ushr => ">>>", + BinOp::And => "&", + BinOp::Or => "|", + BinOp::Xor => "^", + }; + format!( + "{} {} {}", + self.render_expr_parens(left, op), + op_str, + self.render_expr_parens(right, op) + ) + } + Expr::UnaryOp { op, operand } => { + let op_str = match op { + UnaryOp::Neg => "-", + UnaryOp::Not => "!", + }; + format!("{}{}", op_str, self.render_expr(operand)) + } + Expr::Cast { + target_type, + operand, + } => { + format!( + "({}){}", + target_type.simple_name(), + self.render_expr(operand) + ) + } + Expr::Instanceof { + operand, + check_type, + } => { + format!( + "{} instanceof {}", + self.render_expr(operand), + descriptor::simple_class_name(check_type) + ) + } + Expr::FieldGet { + object, + class_name, + field_name, + .. + } => match object { + Some(obj) => format!("{}.{}", self.render_expr(obj), field_name), + None => format!( + "{}.{}", + descriptor::simple_class_name(class_name), + field_name + ), + }, + Expr::MethodCall { + object, + class_name, + method_name, + args, + kind, + .. 
+ } => { + let args_str: Vec<String> = args.iter().map(|a| self.render_expr(a)).collect(); + let args_joined = args_str.join(", "); + match kind { + InvokeKind::Static => { + format!( + "{}.{}({})", + descriptor::simple_class_name(class_name), + method_name, + args_joined + ) + } + _ => { + let receiver = object + .as_ref() + .map(|o| self.render_expr(o)) + .unwrap_or_else(|| "this".into()); + if method_name == "<init>" { + format!("super({})", args_joined) + } else { + format!("{}.{}({})", receiver, method_name, args_joined) + } + } + } + } + Expr::New { + class_name, args, .. + } => { + let args_str: Vec<String> = args.iter().map(|a| self.render_expr(a)).collect(); + format!( + "new {}({})", + descriptor::simple_class_name(class_name), + args_str.join(", ") + ) + } + Expr::NewArray { + element_type, + length, + } => { + format!( + "new {}[{}]", + element_type.simple_name(), + self.render_expr(length) + ) + } + Expr::NewMultiArray { + element_type, + dimensions, + } => { + let dims: Vec<String> = dimensions + .iter() + .map(|d| format!("[{}]", self.render_expr(d))) + .collect(); + format!("new {}{}", element_type.simple_name(), dims.join("")) + } + Expr::ArrayLength { array } => { + format!("{}.length", self.render_expr(array)) + } + Expr::ArrayLoad { array, index, .. } => { + format!("{}[{}]", self.render_expr(array), self.render_expr(index)) + } + Expr::Compare { op, left, right } => { + format!( + "{} {} {}", + self.render_expr(left), + op.as_str(), + self.render_expr(right) + ) + } + Expr::CmpResult { left, right, .. } => { + // This should be folded into a Compare during structuring + format!( + "/* cmp */ {} <=> {}", + self.render_expr(left), + self.render_expr(right) + ) + } + Expr::InvokeDynamic { + method_name, + captures, + .. 
+ } => { + if captures.is_empty() { + format!("/* lambda */ {}()", method_name) + } else { + let caps: Vec<String> = captures.iter().map(|c| self.render_expr(c)).collect(); + format!("/* lambda */ {}({})", method_name, caps.join(", ")) + } + } + Expr::Ternary { + condition, + then_expr, + else_expr, + } => { + format!( + "{} ? {} : {}", + self.render_expr(condition), + self.render_expr(then_expr), + self.render_expr(else_expr) + ) + } + Expr::Unresolved(msg) => msg.clone(), + Expr::Dup(inner) => self.render_expr(inner), + Expr::UninitNew { class_name } => format!( + "/* uninit */ new {}", + descriptor::simple_class_name(class_name) + ), + } + } + + fn render_expr_parens(&self, expr: &Expr, _parent_op: &BinOp) -> String { + // Conservatively parenthesize every nested binary operation; + // operator precedence (_parent_op) is not consulted yet. + match expr { + Expr::BinaryOp { .. } => format!("({})", self.render_expr(expr)), + _ => self.render_expr(expr), + } + } + + fn render_stmt_inline(&self, stmt: &StructuredStmt) -> String { + match stmt { + StructuredStmt::Simple(s) => match s { + Stmt::LocalStore { var, value } => { + let default_name = format!("var{}", var.index); + let var_name = var.name.as_deref().unwrap_or(&default_name); + format!("{} = {}", var_name, self.render_expr(value)) + } + Stmt::Iinc { var, amount } => { + let default_name = format!("var{}", var.index); + let var_name = var.name.as_deref().unwrap_or(&default_name); + if *amount == 1 { + format!("{}++", var_name) + } else { + format!("{} += {}", var_name, amount) + } + } + Stmt::ExprStmt(e) => self.render_expr(e), + _ => "/* stmt */".into(), + }, + _ => "/* stmt */".into(), + } + } + + fn render_annotation(&mut self, ann: &JavaAnnotation) { + let mut s = String::new(); + self.render_annotation_to_string(ann, &mut s); + self.raw(&s); + } + + fn render_annotation_to_string(&self, ann: &JavaAnnotation, out: &mut String) { + out.push('@'); + out.push_str(&ann.type_name); + if !ann.arguments.is_empty() { + out.push('('); + for (i, arg) in 
ann.arguments.iter().enumerate() { + if i > 0 { + out.push_str(", "); + } + match arg { + AnnotationArgument::Named { name, value } => { + if name == "value" && ann.arguments.len() == 1 { + self.render_annotation_value(value, out); + } else { + out.push_str(name); + out.push_str(" = "); + self.render_annotation_value(value, out); + } + } + AnnotationArgument::Unnamed(value) => { + self.render_annotation_value(value, out); + } + } + } + out.push(')'); + } + } + + fn render_annotation_value(&self, value: &AnnotationValue, out: &mut String) { + match value { + AnnotationValue::IntLiteral(v) => write!(out, "{}", v).unwrap(), + AnnotationValue::LongLiteral(v) => write!(out, "{}L", v).unwrap(), + AnnotationValue::FloatLiteral(v) => write!(out, "{}f", v).unwrap(), + AnnotationValue::DoubleLiteral(v) => write!(out, "{}d", v).unwrap(), + AnnotationValue::StringLiteral(s) => { + write!(out, "\"{}\"", escape_java_string(s)).unwrap() + } + AnnotationValue::BooleanLiteral(b) => write!(out, "{}", b).unwrap(), + AnnotationValue::CharLiteral(c) => write!(out, "'{}'", c).unwrap(), + AnnotationValue::ClassLiteral(c) => write!(out, "{}.class", c).unwrap(), + AnnotationValue::EnumConstant { + type_name, + const_name, + } => { + write!(out, "{}.{}", type_name, const_name).unwrap(); + } + AnnotationValue::AnnotationLiteral(ann) => { + self.render_annotation_to_string(ann, out); + } + AnnotationValue::ArrayLiteral(values) => { + out.push('{'); + for (i, v) in values.iter().enumerate() { + if i > 0 { + out.push_str(", "); + } + self.render_annotation_value(v, out); + } + out.push('}'); + } + } + } + + fn should_show_field(&self, field: &JavaField, class_kind: &ClassKind) -> bool { + if !self.config.include_synthetic && field.is_synthetic { + return false; + } + // Skip enum $VALUES + if *class_kind == ClassKind::Enum && field.name == "$VALUES" { + return false; + } + // Skip enum constants in fields list (rendered separately) + if field.is_enum_constant { + return false; + } + true + } + 
+ fn should_show_method(&self, method: &JavaMethod, class_kind: &ClassKind) -> bool { + if !self.config.include_synthetic && (method.is_synthetic || method.is_bridge) { + return false; + } + // Skip static initializer if empty + if method.name == "<clinit>" && method.body.is_none() && method.error.is_none() { + return false; + } + // Skip auto-generated enum methods + if *class_kind == ClassKind::Enum && (method.name == "values" || method.name == "valueOf") { + return false; + } + true + } + + fn collect_imports(&mut self, class: &JavaClass) { + self.collect_type_imports(&class.super_class); + for iface in &class.interfaces { + self.collect_type_import(iface); + } + for field in &class.fields { + self.collect_type_import(&field.field_type); + } + for method in &class.methods { + self.collect_type_import(&method.return_type); + for param in &method.parameters { + self.collect_type_import(&param.param_type); + } + for t in &method.throws { + self.collect_type_import(t); + } + } + } + + fn collect_type_imports(&mut self, ty: &Option<JavaType>) { + if let Some(t) = ty { + self.collect_type_import(t); + } + } + + fn collect_type_import(&mut self, ty: &JavaType) { + match ty { + JavaType::ClassType { + package, + name, + type_args, + } => { + if let Some(pkg) = package + && pkg != "java.lang" + { + self.imports.insert(format!("{}.{}", pkg, name)); + } + for arg in type_args { + self.collect_type_import(arg); + } + } + JavaType::ArrayType(inner) => self.collect_type_import(inner), + JavaType::WildcardType { bound: Some(b), .. 
} => self.collect_type_import(b), + _ => {} + } + } + + fn append_visibility(&self, out: &mut String, vis: &Visibility) { + match vis { + Visibility::Public => out.push_str("public "), + Visibility::Protected => out.push_str("protected "), + Visibility::Private => out.push_str("private "), + Visibility::PackagePrivate => {} + } + } + + fn write_indent(&mut self) { + for _ in 0..self.indent_level { + self.output.push_str(&self.config.indent); + } + } + + fn writeln(&mut self, text: &str) { + self.write_indent(); + self.output.push_str(text); + self.output.push('\n'); + } + + fn newline(&mut self) { + self.output.push('\n'); + } + + fn raw(&mut self, text: &str) { + self.output.push_str(text); + } + + fn raw_newline(&mut self) { + self.output.push('\n'); + } +} + +fn escape_java_string(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for ch in s.chars() { + match ch { + '\\' => out.push_str("\\\\"), + '"' => out.push_str("\\\""), + '\n' => out.push_str("\\n"), + '\r' => out.push_str("\\r"), + '\t' => out.push_str("\\t"), + '\0' => out.push_str("\\0"), + c if c.is_control() => { + write!(out, "\\u{:04x}", c as u32).unwrap(); + } + c => out.push(c), + } + } + out +} diff --git a/src/decompile/stack_sim.rs b/src/decompile/stack_sim.rs new file mode 100644 index 0000000..d9dcb63 --- /dev/null +++ b/src/decompile/stack_sim.rs @@ -0,0 +1,1697 @@ +use crate::attribute_info::{AttributeInfoVariant, CodeAttribute}; +use crate::code_attribute::Instruction; +use crate::constant_info::ConstantInfo; + +use super::cfg_types::*; +use super::descriptor::*; +use super::expr::*; +use super::util; + +use std::collections::HashMap; + +/// Build a lookup table from local variable index to name using the +/// LocalVariableTable sub-attribute of the CodeAttribute (when available). 
+fn build_local_name_table( + code_attr: &CodeAttribute, + const_pool: &[ConstantInfo], +) -> HashMap<u16, String> { + let mut names = HashMap::new(); + for attr in &code_attr.attributes { + if let Some(AttributeInfoVariant::LocalVariableTable(lvt)) = &attr.info_parsed { + for item in &lvt.items { + if let Some(name) = util::get_utf8(const_pool, item.name_index) { + names.insert(item.index, name.to_string()); + } + } + } + } + names +} + +/// Create a LocalVar with optional name lookup. +fn make_local(index: u16, ty: JvmType, names: &HashMap<u16, String>) -> LocalVar { + LocalVar { + index, + name: names.get(&index).cloned(), + ty, + } +} + +/// Resolve a constant pool field reference to (class_name, field_name, field_type). +fn resolve_field_ref(const_pool: &[ConstantInfo], index: u16) -> (String, String, JvmType) { + if let Some((class_name, field_name, descriptor)) = util::resolve_ref(const_pool, index) { + let field_type = parse_type_descriptor(descriptor).unwrap_or(JvmType::Unknown); + (class_name.to_string(), field_name.to_string(), field_type) + } else { + ( + format!("", index), + format!("", index), + JvmType::Unknown, + ) + } +} + +/// Resolve a constant pool method reference to (class_name, method_name, descriptor_str, param_types, return_type). +fn resolve_method_ref( + const_pool: &[ConstantInfo], + index: u16, +) -> (String, String, String, Vec<JvmType>, JvmType) { + if let Some((class_name, method_name, descriptor)) = util::resolve_ref(const_pool, index) { + let (params, ret) = + parse_method_descriptor(descriptor).unwrap_or_else(|| (vec![], JvmType::Unknown)); + ( + class_name.to_string(), + method_name.to_string(), + descriptor.to_string(), + params, + ret, + ) + } else { + ( + format!("", index), + format!("", index), + String::new(), + vec![], + JvmType::Unknown, + ) + } +} + +/// Load a constant from the constant pool by index (for ldc/ldc_w/ldc2_w). 
+fn load_constant(const_pool: &[ConstantInfo], index: u16) -> Expr { + match const_pool.get((index as usize).wrapping_sub(1)) { + Some(ConstantInfo::Integer(c)) => Expr::IntLiteral(c.value), + Some(ConstantInfo::Float(c)) => Expr::FloatLiteral(c.value), + Some(ConstantInfo::Long(c)) => Expr::LongLiteral(c.value), + Some(ConstantInfo::Double(c)) => Expr::DoubleLiteral(c.value), + Some(ConstantInfo::String(c)) => { + if let Some(s) = util::get_utf8(const_pool, c.string_index) { + Expr::StringLiteral(s.to_string()) + } else { + Expr::Unresolved(format!("string_cp#{}", c.string_index)) + } + } + Some(ConstantInfo::Class(c)) => { + if let Some(name) = util::get_utf8(const_pool, c.name_index) { + Expr::ClassLiteral(internal_to_source_name(name).to_string()) + } else { + Expr::Unresolved(format!("class_cp#{}", c.name_index)) + } + } + _ => Expr::Unresolved(format!("cp#{}", index)), + } +} + +/// Simulate a single basic block, converting bytecode instructions into +/// expression trees and statement lists. +pub fn simulate_block( + block: &BasicBlock, + const_pool: &[ConstantInfo], + code_attr: &CodeAttribute, + is_static: bool, +) -> SimulatedBlock { + let local_names = build_local_name_table(code_attr, const_pool); + let mut stack: Vec = Vec::new(); + let mut stmts: Vec = Vec::new(); + let mut branch_condition: Option = None; + + /// Pop from the stack or return an Unresolved placeholder. + macro_rules! 
pop { + ($stack:expr) => { + $stack + .pop() + .unwrap_or(Expr::Unresolved("stack_underflow".to_string())) + }; + } + + for addressed in &block.instructions { + let instr = &addressed.instruction; + match instr { + // ============================================================ + // Constants + // ============================================================ + Instruction::Iconstm1 => stack.push(Expr::IntLiteral(-1)), + Instruction::Iconst0 => stack.push(Expr::IntLiteral(0)), + Instruction::Iconst1 => stack.push(Expr::IntLiteral(1)), + Instruction::Iconst2 => stack.push(Expr::IntLiteral(2)), + Instruction::Iconst3 => stack.push(Expr::IntLiteral(3)), + Instruction::Iconst4 => stack.push(Expr::IntLiteral(4)), + Instruction::Iconst5 => stack.push(Expr::IntLiteral(5)), + + Instruction::Lconst0 => stack.push(Expr::LongLiteral(0)), + Instruction::Lconst1 => stack.push(Expr::LongLiteral(1)), + + Instruction::Fconst0 => stack.push(Expr::FloatLiteral(0.0)), + Instruction::Fconst1 => stack.push(Expr::FloatLiteral(1.0)), + Instruction::Fconst2 => stack.push(Expr::FloatLiteral(2.0)), + + Instruction::Dconst0 => stack.push(Expr::DoubleLiteral(0.0)), + Instruction::Dconst1 => stack.push(Expr::DoubleLiteral(1.0)), + + Instruction::Aconstnull => stack.push(Expr::NullLiteral), + + Instruction::Bipush(val) => stack.push(Expr::IntLiteral(*val as i32)), + Instruction::Sipush(val) => stack.push(Expr::IntLiteral(*val as i32)), + + Instruction::Ldc(idx) => stack.push(load_constant(const_pool, *idx as u16)), + Instruction::LdcW(idx) => stack.push(load_constant(const_pool, *idx)), + Instruction::Ldc2W(idx) => stack.push(load_constant(const_pool, *idx)), + + // ============================================================ + // Loads + // ============================================================ + Instruction::Iload(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx as u16, + JvmType::Int, + &local_names, + ))); + } + Instruction::Iload0 => { + 
stack.push(Expr::LocalLoad(make_local(0, JvmType::Int, &local_names))); + } + Instruction::Iload1 => { + stack.push(Expr::LocalLoad(make_local(1, JvmType::Int, &local_names))); + } + Instruction::Iload2 => { + stack.push(Expr::LocalLoad(make_local(2, JvmType::Int, &local_names))); + } + Instruction::Iload3 => { + stack.push(Expr::LocalLoad(make_local(3, JvmType::Int, &local_names))); + } + + Instruction::Lload(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx as u16, + JvmType::Long, + &local_names, + ))); + } + Instruction::Lload0 => { + stack.push(Expr::LocalLoad(make_local(0, JvmType::Long, &local_names))); + } + Instruction::Lload1 => { + stack.push(Expr::LocalLoad(make_local(1, JvmType::Long, &local_names))); + } + Instruction::Lload2 => { + stack.push(Expr::LocalLoad(make_local(2, JvmType::Long, &local_names))); + } + Instruction::Lload3 => { + stack.push(Expr::LocalLoad(make_local(3, JvmType::Long, &local_names))); + } + + Instruction::Fload(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx as u16, + JvmType::Float, + &local_names, + ))); + } + Instruction::Fload0 => { + stack.push(Expr::LocalLoad(make_local(0, JvmType::Float, &local_names))); + } + Instruction::Fload1 => { + stack.push(Expr::LocalLoad(make_local(1, JvmType::Float, &local_names))); + } + Instruction::Fload2 => { + stack.push(Expr::LocalLoad(make_local(2, JvmType::Float, &local_names))); + } + Instruction::Fload3 => { + stack.push(Expr::LocalLoad(make_local(3, JvmType::Float, &local_names))); + } + + Instruction::Dload(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx as u16, + JvmType::Double, + &local_names, + ))); + } + Instruction::Dload0 => { + stack.push(Expr::LocalLoad(make_local( + 0, + JvmType::Double, + &local_names, + ))); + } + Instruction::Dload1 => { + stack.push(Expr::LocalLoad(make_local( + 1, + JvmType::Double, + &local_names, + ))); + } + Instruction::Dload2 => { + stack.push(Expr::LocalLoad(make_local( + 2, + JvmType::Double, + &local_names, + ))); + 
} + Instruction::Dload3 => { + stack.push(Expr::LocalLoad(make_local( + 3, + JvmType::Double, + &local_names, + ))); + } + + Instruction::Aload(idx) => { + let idx16 = *idx as u16; + if idx16 == 0 && !is_static { + stack.push(Expr::This); + } else { + stack.push(Expr::LocalLoad(make_local( + idx16, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ))); + } + } + Instruction::Aload0 => { + if !is_static { + stack.push(Expr::This); + } else { + stack.push(Expr::LocalLoad(make_local( + 0, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ))); + } + } + Instruction::Aload1 => { + stack.push(Expr::LocalLoad(make_local( + 1, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ))); + } + Instruction::Aload2 => { + stack.push(Expr::LocalLoad(make_local( + 2, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ))); + } + Instruction::Aload3 => { + stack.push(Expr::LocalLoad(make_local( + 3, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ))); + } + + // Wide loads + Instruction::IloadWide(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx, + JvmType::Int, + &local_names, + ))); + } + Instruction::LloadWide(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx, + JvmType::Long, + &local_names, + ))); + } + Instruction::FloadWide(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx, + JvmType::Float, + &local_names, + ))); + } + Instruction::DloadWide(idx) => { + stack.push(Expr::LocalLoad(make_local( + *idx, + JvmType::Double, + &local_names, + ))); + } + Instruction::AloadWide(idx) => { + if *idx == 0 && !is_static { + stack.push(Expr::This); + } else { + stack.push(Expr::LocalLoad(make_local( + *idx, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ))); + } + } + + // ============================================================ + // Stores + // ============================================================ + 
Instruction::Istore(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx as u16, JvmType::Int, &local_names), + value: val, + }); + } + Instruction::Istore0 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(0, JvmType::Int, &local_names), + value: val, + }); + } + Instruction::Istore1 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(1, JvmType::Int, &local_names), + value: val, + }); + } + Instruction::Istore2 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(2, JvmType::Int, &local_names), + value: val, + }); + } + Instruction::Istore3 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(3, JvmType::Int, &local_names), + value: val, + }); + } + + Instruction::Lstore(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx as u16, JvmType::Long, &local_names), + value: val, + }); + } + Instruction::Lstore0 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(0, JvmType::Long, &local_names), + value: val, + }); + } + Instruction::Lstore1 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(1, JvmType::Long, &local_names), + value: val, + }); + } + Instruction::Lstore2 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(2, JvmType::Long, &local_names), + value: val, + }); + } + Instruction::Lstore3 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(3, JvmType::Long, &local_names), + value: val, + }); + } + + Instruction::Fstore(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx as u16, JvmType::Float, &local_names), + value: val, + }); + } + Instruction::Fstore0 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(0, JvmType::Float, &local_names), + value: val, + }); + } + Instruction::Fstore1 => { + let val 
= pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(1, JvmType::Float, &local_names), + value: val, + }); + } + Instruction::Fstore2 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(2, JvmType::Float, &local_names), + value: val, + }); + } + Instruction::Fstore3 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(3, JvmType::Float, &local_names), + value: val, + }); + } + + Instruction::Dstore(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx as u16, JvmType::Double, &local_names), + value: val, + }); + } + Instruction::Dstore0 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(0, JvmType::Double, &local_names), + value: val, + }); + } + Instruction::Dstore1 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(1, JvmType::Double, &local_names), + value: val, + }); + } + Instruction::Dstore2 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(2, JvmType::Double, &local_names), + value: val, + }); + } + Instruction::Dstore3 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(3, JvmType::Double, &local_names), + value: val, + }); + } + + Instruction::Astore(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local( + *idx as u16, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ), + value: val, + }); + } + Instruction::Astore0 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local( + 0, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ), + value: val, + }); + } + Instruction::Astore1 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local( + 1, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ), + value: val, + }); + } + Instruction::Astore2 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + 
var: make_local( + 2, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ), + value: val, + }); + } + Instruction::Astore3 => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local( + 3, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ), + value: val, + }); + } + + // Wide stores + Instruction::IstoreWide(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx, JvmType::Int, &local_names), + value: val, + }); + } + Instruction::LstoreWide(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx, JvmType::Long, &local_names), + value: val, + }); + } + Instruction::FstoreWide(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx, JvmType::Float, &local_names), + value: val, + }); + } + Instruction::DstoreWide(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local(*idx, JvmType::Double, &local_names), + value: val, + }); + } + Instruction::AstoreWide(idx) => { + let val = pop!(stack); + stmts.push(Stmt::LocalStore { + var: make_local( + *idx, + JvmType::Reference("java/lang/Object".to_string()), + &local_names, + ), + value: val, + }); + } + + // ============================================================ + // Array loads + // ============================================================ + Instruction::Iaload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Int, + }); + } + Instruction::Laload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Long, + }); + } + Instruction::Faload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: 
JvmType::Float, + }); + } + Instruction::Daload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Double, + }); + } + Instruction::Aaload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Reference("java/lang/Object".to_string()), + }); + } + Instruction::Baload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Byte, + }); + } + Instruction::Caload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Char, + }); + } + Instruction::Saload => { + let index = pop!(stack); + let array = pop!(stack); + stack.push(Expr::ArrayLoad { + array: Box::new(array), + index: Box::new(index), + element_type: JvmType::Short, + }); + } + + // ============================================================ + // Array stores + // ============================================================ + Instruction::Iastore + | Instruction::Lastore + | Instruction::Fastore + | Instruction::Dastore + | Instruction::Aastore + | Instruction::Bastore + | Instruction::Castore + | Instruction::Sastore => { + let value = pop!(stack); + let index = pop!(stack); + let array = pop!(stack); + stmts.push(Stmt::ArrayStore { + array, + index, + value, + }); + } + + // ============================================================ + // Stack manipulation + // ============================================================ + Instruction::Pop => { + let val = pop!(stack); + // If the popped value has side effects, emit it as a statement. 
+ if has_side_effects(&val) { + stmts.push(Stmt::ExprStmt(val)); + } + } + Instruction::Pop2 => { + // Pop2 removes top one or two computational units. + // We treat it as two pops for simplicity. + let val1 = pop!(stack); + if has_side_effects(&val1) { + stmts.push(Stmt::ExprStmt(val1)); + } + if !stack.is_empty() { + let val2 = pop!(stack); + if has_side_effects(&val2) { + stmts.push(Stmt::ExprStmt(val2)); + } + } + } + Instruction::Dup => { + let val = pop!(stack); + let dup = Expr::Dup(Box::new(val.clone())); + stack.push(val); + stack.push(dup); + } + Instruction::Dupx1 => { + // ..., value2, value1 -> ..., value1, value2, value1 + let val1 = pop!(stack); + let val2 = pop!(stack); + let dup = Expr::Dup(Box::new(val1.clone())); + stack.push(dup); + stack.push(val2); + stack.push(val1); + } + Instruction::Dupx2 => { + // ..., value3, value2, value1 -> ..., value1, value3, value2, value1 + let val1 = pop!(stack); + let val2 = pop!(stack); + let val3 = pop!(stack); + let dup = Expr::Dup(Box::new(val1.clone())); + stack.push(dup); + stack.push(val3); + stack.push(val2); + stack.push(val1); + } + Instruction::Dup2 => { + // ..., value2, value1 -> ..., value2, value1, value2, value1 + let val1 = pop!(stack); + let val2 = pop!(stack); + let dup2 = Expr::Dup(Box::new(val2.clone())); + let dup1 = Expr::Dup(Box::new(val1.clone())); + stack.push(val2); + stack.push(val1); + stack.push(dup2); + stack.push(dup1); + } + Instruction::Dup2x1 => { + // ..., value3, value2, value1 -> ..., value2, value1, value3, value2, value1 + let val1 = pop!(stack); + let val2 = pop!(stack); + let val3 = pop!(stack); + let dup2 = Expr::Dup(Box::new(val2.clone())); + let dup1 = Expr::Dup(Box::new(val1.clone())); + stack.push(dup2); + stack.push(dup1); + stack.push(val3); + stack.push(val2); + stack.push(val1); + } + Instruction::Dup2x2 => { + // ..., value4, value3, value2, value1 -> ..., value2, value1, value4, value3, value2, value1 + let val1 = pop!(stack); + let val2 = pop!(stack); + let 
val3 = pop!(stack); + let val4 = pop!(stack); + let dup2 = Expr::Dup(Box::new(val2.clone())); + let dup1 = Expr::Dup(Box::new(val1.clone())); + stack.push(dup2); + stack.push(dup1); + stack.push(val4); + stack.push(val3); + stack.push(val2); + stack.push(val1); + } + Instruction::Swap => { + let val1 = pop!(stack); + let val2 = pop!(stack); + stack.push(val1); + stack.push(val2); + } + + // ============================================================ + // Arithmetic + // ============================================================ + Instruction::Iadd | Instruction::Ladd | Instruction::Fadd | Instruction::Dadd => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Add, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Isub | Instruction::Lsub | Instruction::Fsub | Instruction::Dsub => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Sub, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Imul | Instruction::Lmul | Instruction::Fmul | Instruction::Dmul => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Mul, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Idiv | Instruction::Ldiv | Instruction::Fdiv | Instruction::Ddiv => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Div, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Irem | Instruction::Lrem | Instruction::Frem | Instruction::Drem => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Rem, + left: Box::new(left), + right: Box::new(right), + }); + } + + Instruction::Ineg | Instruction::Lneg | Instruction::Fneg | Instruction::Dneg => { + let operand = pop!(stack); + stack.push(Expr::UnaryOp { + op: UnaryOp::Neg, + operand: Box::new(operand), + }); + } + + // 
============================================================ + // Bitwise / shift + // ============================================================ + Instruction::Ishl | Instruction::Lshl => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Shl, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Ishr | Instruction::Lshr => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Shr, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Iushr | Instruction::Lushr => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Ushr, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Iand | Instruction::Land => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::And, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Ior | Instruction::Lor => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Or, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Ixor | Instruction::Lxor => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::BinaryOp { + op: BinOp::Xor, + left: Box::new(left), + right: Box::new(right), + }); + } + + // ============================================================ + // Conversions (casts) + // ============================================================ + Instruction::I2l => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Long, + operand: Box::new(operand), + }); + } + Instruction::I2f => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Float, + operand: Box::new(operand), + }); + } + Instruction::I2d => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Double, + operand: Box::new(operand), + }); + } + 
Instruction::L2i => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Int, + operand: Box::new(operand), + }); + } + Instruction::L2f => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Float, + operand: Box::new(operand), + }); + } + Instruction::L2d => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Double, + operand: Box::new(operand), + }); + } + Instruction::F2i => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Int, + operand: Box::new(operand), + }); + } + Instruction::F2l => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Long, + operand: Box::new(operand), + }); + } + Instruction::F2d => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Double, + operand: Box::new(operand), + }); + } + Instruction::D2i => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Int, + operand: Box::new(operand), + }); + } + Instruction::D2l => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Long, + operand: Box::new(operand), + }); + } + Instruction::D2f => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Float, + operand: Box::new(operand), + }); + } + Instruction::I2b => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Byte, + operand: Box::new(operand), + }); + } + Instruction::I2c => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Char, + operand: Box::new(operand), + }); + } + Instruction::I2s => { + let operand = pop!(stack); + stack.push(Expr::Cast { + target_type: JvmType::Short, + operand: Box::new(operand), + }); + } + + // ============================================================ + // Comparisons (push -1/0/1 result) + // ============================================================ + Instruction::Lcmp => { + let right 
= pop!(stack); + let left = pop!(stack); + stack.push(Expr::CmpResult { + kind: CmpKind::LCmp, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Fcmpl => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::CmpResult { + kind: CmpKind::FCmpL, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Fcmpg => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::CmpResult { + kind: CmpKind::FCmpG, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Dcmpl => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::CmpResult { + kind: CmpKind::DCmpL, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::Dcmpg => { + let right = pop!(stack); + let left = pop!(stack); + stack.push(Expr::CmpResult { + kind: CmpKind::DCmpG, + left: Box::new(left), + right: Box::new(right), + }); + } + + // ============================================================ + // Conditional branches (set branch_condition) + // ============================================================ + Instruction::Ifeq(_) => { + let val = pop!(stack); + branch_condition = Some(make_if_zero_cond(val, CompareOp::Eq)); + } + Instruction::Ifne(_) => { + let val = pop!(stack); + branch_condition = Some(make_if_zero_cond(val, CompareOp::Ne)); + } + Instruction::Iflt(_) => { + let val = pop!(stack); + branch_condition = Some(make_if_zero_cond(val, CompareOp::Lt)); + } + Instruction::Ifge(_) => { + let val = pop!(stack); + branch_condition = Some(make_if_zero_cond(val, CompareOp::Ge)); + } + Instruction::Ifgt(_) => { + let val = pop!(stack); + branch_condition = Some(make_if_zero_cond(val, CompareOp::Gt)); + } + Instruction::Ifle(_) => { + let val = pop!(stack); + branch_condition = Some(make_if_zero_cond(val, CompareOp::Le)); + } + + Instruction::IfIcmpeq(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: 
CompareOp::Eq, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::IfIcmpne(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Ne, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::IfIcmplt(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Lt, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::IfIcmpge(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Ge, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::IfIcmpgt(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Gt, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::IfIcmple(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Le, + left: Box::new(left), + right: Box::new(right), + }); + } + + Instruction::IfAcmpeq(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Eq, + left: Box::new(left), + right: Box::new(right), + }); + } + Instruction::IfAcmpne(_) => { + let right = pop!(stack); + let left = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Ne, + left: Box::new(left), + right: Box::new(right), + }); + } + + Instruction::Ifnull(_) => { + let val = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Eq, + left: Box::new(val), + right: Box::new(Expr::NullLiteral), + }); + } + Instruction::Ifnonnull(_) => { + let val = pop!(stack); + branch_condition = Some(Expr::Compare { + op: CompareOp::Ne, + left: Box::new(val), + right: Box::new(Expr::NullLiteral), + }); + } + + // ============================================================ + // 
Unconditional branches (goto, tableswitch, lookupswitch) + // These are terminators; the stack state is the exit_stack. + // ============================================================ + Instruction::Goto(_) | Instruction::GotoW(_) => { + // No stack effect; terminator is already recorded in the block. + } + + Instruction::Tableswitch { .. } | Instruction::Lookupswitch { .. } => { + // The switch key is popped from the stack. + let _key = pop!(stack); + } + + // ============================================================ + // iinc + // ============================================================ + Instruction::Iinc { index, value } => { + stmts.push(Stmt::Iinc { + var: make_local(*index as u16, JvmType::Int, &local_names), + amount: *value as i32, + }); + } + Instruction::IincWide { index, value } => { + stmts.push(Stmt::Iinc { + var: make_local(*index, JvmType::Int, &local_names), + amount: *value as i32, + }); + } + + // ============================================================ + // Field access + // ============================================================ + Instruction::Getfield(idx) => { + let (class_name, field_name, field_type) = resolve_field_ref(const_pool, *idx); + let object = pop!(stack); + stack.push(Expr::FieldGet { + object: Some(Box::new(object)), + class_name, + field_name, + field_type, + }); + } + Instruction::Getstatic(idx) => { + let (class_name, field_name, field_type) = resolve_field_ref(const_pool, *idx); + stack.push(Expr::FieldGet { + object: None, + class_name, + field_name, + field_type, + }); + } + Instruction::Putfield(idx) => { + let (class_name, field_name, field_type) = resolve_field_ref(const_pool, *idx); + let value = pop!(stack); + let object = pop!(stack); + stmts.push(Stmt::FieldStore { + object: Some(object), + class_name, + field_name, + field_type, + value, + }); + } + Instruction::Putstatic(idx) => { + let (class_name, field_name, field_type) = resolve_field_ref(const_pool, *idx); + let value = pop!(stack); + 
stmts.push(Stmt::FieldStore { + object: None, + class_name, + field_name, + field_type, + value, + }); + } + + // ============================================================ + // Method invocation + // ============================================================ + Instruction::Invokevirtual(idx) => { + let (class_name, method_name, descriptor, param_types, return_type) = + resolve_method_ref(const_pool, *idx); + let args = pop_args(&mut stack, &param_types); + let object = pop!(stack); + let call = Expr::MethodCall { + kind: InvokeKind::Virtual, + object: Some(Box::new(object)), + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + push_or_emit_call(call, &return_type, &mut stack, &mut stmts); + } + + Instruction::Invokespecial(idx) => { + let (class_name, method_name, descriptor, param_types, return_type) = + resolve_method_ref(const_pool, *idx); + let args = pop_args(&mut stack, &param_types); + let receiver = pop!(stack); + + if method_name == "<init>" { + // Detect new;dup;invokespecial pattern: collapse to Expr::New. + match receiver { + Expr::Dup(inner) => match *inner { + Expr::UninitNew { class_name: ref cn } => { + let new_expr = Expr::New { + class_name: cn.clone(), + constructor_descriptor: descriptor, + args, + }; + // The dup placed a copy on the stack; replace the original + // UninitNew that is still on the stack with the New expression. + replace_uninit_new(&mut stack, cn, &new_expr); + stack.push(new_expr); + } + _ => { + // Calling <init> on something other than a fresh `new` (e.g., super() or this()) + let call = Expr::MethodCall { + kind: InvokeKind::Special, + object: Some(Box::new(Expr::Dup(inner))), + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + stmts.push(Stmt::ExprStmt(call)); + } + }, + Expr::UninitNew { ref class_name } => { + // new without dup; the result is discarded or stored immediately. 
+ let new_expr = Expr::New { + class_name: class_name.clone(), + constructor_descriptor: descriptor, + args, + }; + stack.push(new_expr); + } + Expr::This => { + // super.<init> or this() call + let call = Expr::MethodCall { + kind: InvokeKind::Special, + object: Some(Box::new(Expr::This)), + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + stmts.push(Stmt::ExprStmt(call)); + } + _ => { + // Generic invokespecial on unknown receiver + let call = Expr::MethodCall { + kind: InvokeKind::Special, + object: Some(Box::new(receiver)), + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + stmts.push(Stmt::ExprStmt(call)); + } + } + } else { + // Non-<init> invokespecial (private methods, super calls) + let call = Expr::MethodCall { + kind: InvokeKind::Special, + object: Some(Box::new(receiver)), + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + push_or_emit_call(call, &return_type, &mut stack, &mut stmts); + } + } + + Instruction::Invokestatic(idx) => { + let (class_name, method_name, descriptor, param_types, return_type) = + resolve_method_ref(const_pool, *idx); + let args = pop_args(&mut stack, &param_types); + let call = Expr::MethodCall { + kind: InvokeKind::Static, + object: None, + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + push_or_emit_call(call, &return_type, &mut stack, &mut stmts); + } + + Instruction::Invokeinterface { index, .. 
} => { + let (class_name, method_name, descriptor, param_types, return_type) = + resolve_method_ref(const_pool, *index); + let args = pop_args(&mut stack, &param_types); + let object = pop!(stack); + let call = Expr::MethodCall { + kind: InvokeKind::Interface, + object: Some(Box::new(object)), + class_name, + method_name, + descriptor, + args, + return_type: return_type.clone(), + }; + push_or_emit_call(call, &return_type, &mut stack, &mut stmts); + } + + Instruction::Invokedynamic { index, .. } => { + // Resolve the InvokeDynamic constant pool entry. + let (bootstrap_index, method_name, descriptor, param_types, return_type) = + resolve_invokedynamic(const_pool, *index); + let captures = pop_args(&mut stack, &param_types); + let expr = Expr::InvokeDynamic { + bootstrap_index, + method_name, + descriptor: descriptor.clone(), + captures, + }; + if return_type == JvmType::Void { + stmts.push(Stmt::ExprStmt(expr)); + } else { + stack.push(expr); + } + } + + // ============================================================ + // Object creation + // ============================================================ + Instruction::New(idx) => { + let class_name = util::get_class_name(const_pool, *idx) + .unwrap_or("") + .to_string(); + stack.push(Expr::UninitNew { class_name }); + } + + Instruction::Newarray(atype) => { + let length = pop!(stack); + let element_type = newarray_type(*atype); + stack.push(Expr::NewArray { + element_type, + length: Box::new(length), + }); + } + + Instruction::Anewarray(idx) => { + let length = pop!(stack); + let class_name = util::get_class_name(const_pool, *idx) + .unwrap_or("java/lang/Object") + .to_string(); + stack.push(Expr::NewArray { + element_type: JvmType::Reference(class_name), + length: Box::new(length), + }); + } + + Instruction::Multianewarray { index, dimensions } => { + let dim_count = *dimensions as usize; + let mut dims = Vec::with_capacity(dim_count); + for _ in 0..dim_count { + dims.push(pop!(stack)); + } + dims.reverse(); + let 
class_name = util::get_class_name(const_pool, *index) + .unwrap_or("[Ljava/lang/Object;") + .to_string(); + let element_type = parse_type_descriptor(&class_name).unwrap_or(JvmType::Unknown); + stack.push(Expr::NewMultiArray { + element_type, + dimensions: dims, + }); + } + + // ============================================================ + // Misc object operations + // ============================================================ + Instruction::Arraylength => { + let array = pop!(stack); + stack.push(Expr::ArrayLength { + array: Box::new(array), + }); + } + + Instruction::Checkcast(idx) => { + let operand = pop!(stack); + let class_name = util::get_class_name(const_pool, *idx) + .unwrap_or("java/lang/Object") + .to_string(); + stack.push(Expr::Cast { + target_type: JvmType::Reference(class_name), + operand: Box::new(operand), + }); + } + + Instruction::Instanceof(idx) => { + let operand = pop!(stack); + let class_name = util::get_class_name(const_pool, *idx) + .unwrap_or("java/lang/Object") + .to_string(); + stack.push(Expr::Instanceof { + operand: Box::new(operand), + check_type: class_name, + }); + } + + // ============================================================ + // Monitor + // ============================================================ + Instruction::Monitorenter => { + let object = pop!(stack); + stmts.push(Stmt::Monitor { + enter: true, + object, + }); + } + Instruction::Monitorexit => { + let object = pop!(stack); + stmts.push(Stmt::Monitor { + enter: false, + object, + }); + } + + // ============================================================ + // Returns + // ============================================================ + Instruction::Return => { + stmts.push(Stmt::Return(None)); + } + Instruction::Ireturn + | Instruction::Lreturn + | Instruction::Freturn + | Instruction::Dreturn + | Instruction::Areturn => { + let val = pop!(stack); + stmts.push(Stmt::Return(Some(val))); + } + + // ============================================================ + // 
Throw + // ============================================================ + Instruction::Athrow => { + let val = pop!(stack); + stmts.push(Stmt::Throw(val)); + } + + // ============================================================ + // Nop + // ============================================================ + Instruction::Nop => {} + + // ============================================================ + // jsr/ret (legacy, used for finally blocks in old javac) + // ============================================================ + Instruction::Jsr(_) | Instruction::JsrW(_) => { + // Push the return address as an unresolved marker. + stack.push(Expr::Unresolved("jsr_return_address".to_string())); + } + Instruction::Ret(_) | Instruction::RetWide(_) => { + // ret returns to the jsr caller; no stack effect modeled. + } + + // Catch-all for any unhandled instruction + #[allow(unreachable_patterns)] + other => { + stack.push(Expr::Unresolved(format!("{:?}", other))); + } + } + } + + SimulatedBlock { + id: block.id, + statements: stmts, + exit_stack: stack, + terminator: block.terminator.clone(), + branch_condition, + } +} + +/// Simulate all blocks in a control flow graph. +pub fn simulate_all_blocks( + cfg: &ControlFlowGraph, + const_pool: &[ConstantInfo], + code_attr: &CodeAttribute, + is_static: bool, +) -> Vec<SimulatedBlock> { + cfg.blocks + .values() + .map(|block| simulate_block(block, const_pool, code_attr, is_static)) + .collect() +} + +// --------------------------------------------------------------------------- +// Helper functions +// --------------------------------------------------------------------------- + +/// Build a branch condition for `if` opcodes that compare against zero. +/// If the operand is already a CmpResult (from lcmp/fcmp/dcmp), we can +/// fold the comparison into a direct Compare expression. 
+fn make_if_zero_cond(val: Expr, op: CompareOp) -> Expr { + match val { + Expr::CmpResult { + kind: _, + ref left, + ref right, + } => { + // The cmp result is compared to 0 with the given op. + // We can fold this: e.g., `fcmpl(a, b) < 0` becomes `a < b`. + Expr::Compare { + op, + left: left.clone(), + right: right.clone(), + } + } + _ => Expr::Compare { + op, + left: Box::new(val), + right: Box::new(Expr::IntLiteral(0)), + }, + } +} + +/// Pop `n` arguments from the stack (right-to-left in JVM order). +/// Returns them in left-to-right order for display. +fn pop_args(stack: &mut Vec<Expr>, param_types: &[JvmType]) -> Vec<Expr> { + let n = param_types.len(); + let mut args = Vec::with_capacity(n); + for _ in 0..n { + args.push( + stack + .pop() + .unwrap_or(Expr::Unresolved("missing_arg".to_string())), + ); + } + args.reverse(); + args +} + +/// If a method returns void, emit the call as a statement; otherwise push the result. +fn push_or_emit_call( + call: Expr, + return_type: &JvmType, + stack: &mut Vec<Expr>, + stmts: &mut Vec<Stmt>, +) { + if *return_type == JvmType::Void { + stmts.push(Stmt::ExprStmt(call)); + } else { + stack.push(call); + } +} + +/// Replace the topmost UninitNew with a matching class name on the stack. +/// This handles the pattern: new Foo -> dup -> args -> invokespecial +/// After we collapse the dup+invokespecial into Expr::New, we need to +/// replace the original UninitNew that was left below the dup. +fn replace_uninit_new(stack: &mut Vec<Expr>, class_name: &str, replacement: &Expr) { + for item in stack.iter_mut().rev() { + if let Expr::UninitNew { class_name: cn } = item + && cn == class_name + { + *item = replacement.clone(); + return; + } + } +} + +/// Resolve an InvokeDynamic constant pool entry. +/// Returns (bootstrap_method_attr_index, method_name, descriptor, param_types, return_type). 
+fn resolve_invokedynamic( + const_pool: &[ConstantInfo], + index: u16, +) -> (u16, String, String, Vec<JvmType>, JvmType) { + match const_pool.get((index as usize).wrapping_sub(1)) { + Some(ConstantInfo::InvokeDynamic(indy)) => { + if let Some((name, desc)) = + util::get_name_and_type(const_pool, indy.name_and_type_index) + { + let (params, ret) = + parse_method_descriptor(desc).unwrap_or_else(|| (vec![], JvmType::Unknown)); + ( + indy.bootstrap_method_attr_index, + name.to_string(), + desc.to_string(), + params, + ret, + ) + } else { + ( + indy.bootstrap_method_attr_index, + format!("", index), + String::new(), + vec![], + JvmType::Unknown, + ) + } + } + _ => ( + 0, + format!("", index), + String::new(), + vec![], + JvmType::Unknown, + ), + } +} + +/// Heuristic: does this expression likely have side effects? +/// Used to decide whether to emit a popped value as a statement. +fn has_side_effects(expr: &Expr) -> bool { + matches!( + expr, + Expr::MethodCall { .. } + | Expr::New { .. } + | Expr::InvokeDynamic { .. } + | Expr::Unresolved(_) + ) +} diff --git a/src/decompile/structured_types.rs b/src/decompile/structured_types.rs new file mode 100644 index 0000000..2109213 --- /dev/null +++ b/src/decompile/structured_types.rs @@ -0,0 +1,121 @@ +use super::cfg_types::BlockId; +use super::expr::{Expr, LocalVar, Stmt}; + +/// A structured statement — the result of control flow structuring. +/// Represents Java-level control flow constructs. +#[derive(Clone, Debug)] +pub enum StructuredStmt { + /// A simple statement (from stack simulation). + Simple(Stmt), + /// A sequence of statements. 
+ Block(Vec<StructuredStmt>), + /// if / if-else + If { + condition: Expr, + then_body: Box<StructuredStmt>, + else_body: Option<Box<StructuredStmt>>, + }, + /// while loop + While { + condition: Expr, + body: Box<StructuredStmt>, + }, + /// do-while loop + DoWhile { + body: Box<StructuredStmt>, + condition: Expr, + }, + /// for loop + For { + init: Option<Box<StructuredStmt>>, + condition: Expr, + update: Option<Box<StructuredStmt>>, + body: Box<StructuredStmt>, + }, + /// for-each loop (desugared from iterator or array index pattern) + ForEach { + var: LocalVar, + iterable: Expr, + body: Box<StructuredStmt>, + }, + /// switch statement + Switch { + expr: Expr, + cases: Vec<SwitchCase>, + default: Option<Box<StructuredStmt>>, + }, + /// try-catch-finally + TryCatch { + try_body: Box<StructuredStmt>, + catches: Vec<CatchClause>, + finally_body: Option<Box<StructuredStmt>>, + }, + /// try-with-resources (desugared) + TryWithResources { + resources: Vec<(LocalVar, Expr)>, + body: Box<StructuredStmt>, + catches: Vec<CatchClause>, + }, + /// synchronized block + Synchronized { + object: Expr, + body: Box<StructuredStmt>, + }, + /// Labeled statement (for break/continue targets) + Labeled { + label: String, + body: Box<StructuredStmt>, + }, + /// break statement + Break { label: Option<String> }, + /// continue statement + Continue { label: Option<String> }, + /// assert statement (desugared) + Assert { + condition: Expr, + message: Option, + }, + /// Fallback for irreducible control flow + UnstructuredGoto { target: BlockId }, + /// Comment (used for error recovery, bytecode fallback, etc.) + Comment(String), +} + +/// A switch case arm. +#[derive(Clone, Debug)] +pub struct SwitchCase { + pub values: Vec<SwitchValue>, + pub body: StructuredStmt, + pub falls_through: bool, +} + +/// Value for a switch case label. +#[derive(Clone, Debug)] +pub enum SwitchValue { + Int(i32), + String(String), + Enum { + type_name: String, + const_name: String, + }, +} + +/// A catch clause in a try-catch. +#[derive(Clone, Debug)] +pub struct CatchClause { + pub exception_type: Option, + pub var: LocalVar, + pub body: StructuredStmt, +} + +/// A structured method body: the sequence of structured statements. 
+#[derive(Clone, Debug)] +pub struct StructuredBody { + pub statements: Vec<StructuredStmt>, +} + +impl StructuredBody { + pub fn new(statements: Vec<StructuredStmt>) -> Self { + Self { statements } + } +} diff --git a/src/decompile/structuring.rs b/src/decompile/structuring.rs new file mode 100644 index 0000000..b9ef72c --- /dev/null +++ b/src/decompile/structuring.rs @@ -0,0 +1,529 @@ +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; + +use super::cfg_types::*; +use super::expr::*; +use super::structured_types::*; + +/// Convert a CFG with simulated blocks into a structured body. +pub fn structure_method( + cfg: &ControlFlowGraph, + simulated: &[SimulatedBlock], + const_pool: &[crate::constant_info::ConstantInfo], +) -> StructuredBody { + let sim_map: BTreeMap<BlockId, &SimulatedBlock> = simulated.iter().map(|b| (b.id, b)).collect(); + + if cfg.blocks.is_empty() { + return StructuredBody::new(vec![]); + } + + let rpo = cfg.reverse_postorder(); + let dominators = compute_dominators(cfg, &rpo); + let loop_headers = find_loop_headers(cfg, &dominators); + let post_dominators = compute_post_dominators(cfg, &rpo); + + let mut ctx = StructuringContext { + cfg, + sim_map: &sim_map, + dominators, + post_dominators, + loop_headers, + const_pool, + visited: HashSet::new(), + label_counter: 0, + }; + + let stmts = ctx.structure_region(&rpo); + StructuredBody::new(stmts) +} + +struct StructuringContext<'a> { + cfg: &'a ControlFlowGraph, + sim_map: &'a BTreeMap<BlockId, &'a SimulatedBlock>, + dominators: HashMap<BlockId, BlockId>, + post_dominators: HashMap<BlockId, BlockId>, + loop_headers: HashSet<BlockId>, + const_pool: &'a [crate::constant_info::ConstantInfo], + visited: HashSet<BlockId>, + label_counter: usize, +} + +impl<'a> StructuringContext<'a> { + fn fresh_label(&mut self) -> String { + self.label_counter += 1; + format!("label{}", self.label_counter) + } + + fn structure_region(&mut self, order: &[BlockId]) -> Vec<StructuredStmt> { + let mut result = Vec::new(); + + for &block_id in order { + if self.visited.contains(&block_id) { + continue; + } + self.visited.insert(block_id); + + let sim = match 
self.sim_map.get(&block_id) { + Some(s) => *s, + None => continue, + }; + + // Emit the block's statements + for stmt in &sim.statements { + result.push(StructuredStmt::Simple(stmt.clone())); + } + + // Structure the terminator + match &sim.terminator { + Terminator::Return | Terminator::Throw => { + // Statements already include the Return/Throw + } + Terminator::FallThrough { target } => { + if !self.visited.contains(target) && self.loop_headers.contains(target) { + // Back-edge to a loop header — emit continue + result.push(StructuredStmt::Continue { label: None }); + } + // Otherwise the next block in order will handle it + } + Terminator::Goto { target } => { + if self.visited.contains(target) { + if self.loop_headers.contains(target) { + result.push(StructuredStmt::Continue { label: None }); + } else { + result.push(StructuredStmt::UnstructuredGoto { target: *target }); + } + } + // Forward goto is handled by visiting the target later + } + Terminator::ConditionalBranch { + condition: _, + if_true, + if_false, + } => { + let cond_expr = sim + .branch_condition + .clone() + .unwrap_or_else(|| Expr::Unresolved("/* condition */".into())); + + let if_true = *if_true; + let if_false = *if_false; + + // Check if this is a loop header + if self.loop_headers.contains(&block_id) { + let body_entry = if_true; + let exit = if_false; + + // Check which branch goes back (is the loop body) + let (body_start, loop_exit, negate) = if self.dominates(block_id, if_true) + && !self.visited.contains(&if_true) + { + (if_true, if_false, false) + } else if self.dominates(block_id, if_false) + && !self.visited.contains(&if_false) + { + (if_false, if_true, true) + } else { + (body_entry, exit, false) + }; + + let condition = if negate { + negate_expr(cond_expr) + } else { + cond_expr + }; + + // Collect loop body blocks + let loop_body_order = self.collect_loop_body(block_id, body_start); + let body_stmts = self.structure_region(&loop_body_order); + + let body = if 
body_stmts.is_empty() { + StructuredStmt::Block(vec![]) + } else { + StructuredStmt::Block(body_stmts) + }; + + result.push(StructuredStmt::While { + condition, + body: Box::new(body), + }); + + // Continue with the loop exit + if !self.visited.contains(&loop_exit) { + let exit_stmts = self.structure_region(&[loop_exit]); + result.extend(exit_stmts); + } + } else { + // If-else structure + let join_point = self.post_dominators.get(&block_id).copied(); + + let then_stmts = if !self.visited.contains(&if_true) { + let then_order = self.collect_until(if_true, join_point); + self.structure_region(&then_order) + } else { + vec![] + }; + + let else_stmts = if !self.visited.contains(&if_false) { + let else_order = self.collect_until(if_false, join_point); + self.structure_region(&else_order) + } else { + vec![] + }; + + let then_body = Box::new(StructuredStmt::Block(then_stmts)); + let else_body = if else_stmts.is_empty() { + None + } else { + Some(Box::new(StructuredStmt::Block(else_stmts))) + }; + + result.push(StructuredStmt::If { + condition: cond_expr, + then_body, + else_body, + }); + + // Continue with the join point + if let Some(jp) = join_point + && !self.visited.contains(&jp) + { + let jp_stmts = self.structure_region(&[jp]); + result.extend(jp_stmts); + } + } + } + Terminator::TableSwitch { + default, + low, + high: _, + targets, + } => { + let switch_expr = sim + .exit_stack + .last() + .cloned() + .unwrap_or(Expr::Unresolved("/* switch expr */".into())); + + let mut cases = Vec::new(); + for (i, &target) in targets.iter().enumerate() { + let value = *low + i as i32; + if !self.visited.contains(&target) { + self.visited.insert(target); + let body_stmts = if let Some(s) = self.sim_map.get(&target) { + s.statements + .iter() + .map(|st| StructuredStmt::Simple(st.clone())) + .collect() + } else { + vec![] + }; + cases.push(SwitchCase { + values: vec![SwitchValue::Int(value)], + body: StructuredStmt::Block(body_stmts), + falls_through: false, + }); + } + } + + let 
default_body = if !self.visited.contains(default) { + self.visited.insert(*default); + self.sim_map.get(default).map(|s| { + Box::new(StructuredStmt::Block( + s.statements + .iter() + .map(|st| StructuredStmt::Simple(st.clone())) + .collect(), + )) + }) + } else { + None + }; + + result.push(StructuredStmt::Switch { + expr: switch_expr, + cases, + default: default_body, + }); + } + Terminator::LookupSwitch { default, pairs } => { + let switch_expr = sim + .exit_stack + .last() + .cloned() + .unwrap_or(Expr::Unresolved("/* switch expr */".into())); + + let mut cases = Vec::new(); + for (key, target) in pairs { + if !self.visited.contains(target) { + self.visited.insert(*target); + let body_stmts = if let Some(s) = self.sim_map.get(target) { + s.statements + .iter() + .map(|st| StructuredStmt::Simple(st.clone())) + .collect() + } else { + vec![] + }; + cases.push(SwitchCase { + values: vec![SwitchValue::Int(*key)], + body: StructuredStmt::Block(body_stmts), + falls_through: false, + }); + } + } + + let default_body = if !self.visited.contains(default) { + self.visited.insert(*default); + self.sim_map.get(default).map(|s| { + Box::new(StructuredStmt::Block( + s.statements + .iter() + .map(|st| StructuredStmt::Simple(st.clone())) + .collect(), + )) + }) + } else { + None + }; + + result.push(StructuredStmt::Switch { + expr: switch_expr, + cases, + default: default_body, + }); + } + Terminator::Jsr { .. } => { + result.push(StructuredStmt::Comment("/* jsr subroutine */".into())); + } + } + } + + // Handle exception edges -> try-catch + self.structure_exception_handlers(&mut result); + + result + } + + fn structure_exception_handlers(&self, _result: &mut Vec<StructuredStmt>) { + // Exception handler structuring is done at a higher level in class_decompiler + // when we have full context. For now, exception edges create additional entry + // points that are visited as part of the normal flow. 
+ } + + fn dominates(&self, a: BlockId, b: BlockId) -> bool { + let mut current = b; + loop { + if current == a { + return true; + } + match self.dominators.get(&current) { + Some(&dom) if dom != current => current = dom, + _ => return false, + } + } + } + + fn collect_loop_body(&self, header: BlockId, body_start: BlockId) -> Vec { + // Collect all blocks reachable from body_start that are dominated by header + let mut body = Vec::new(); + let mut worklist = vec![body_start]; + let mut seen = HashSet::new(); + seen.insert(header); // Don't re-visit the header + + while let Some(bid) = worklist.pop() { + if !seen.insert(bid) { + continue; + } + if !self.cfg.blocks.contains_key(&bid) { + continue; + } + body.push(bid); + for succ in self.cfg.successors(bid) { + if !seen.contains(&succ) { + worklist.push(succ); + } + } + } + body.sort(); + body + } + + fn collect_until(&self, start: BlockId, stop: Option) -> Vec { + let mut result = Vec::new(); + let mut worklist = vec![start]; + let mut seen = HashSet::new(); + + while let Some(bid) = worklist.pop() { + if let Some(stop_id) = stop + && bid == stop_id + { + continue; + } + if !seen.insert(bid) { + continue; + } + if !self.cfg.blocks.contains_key(&bid) { + continue; + } + result.push(bid); + for succ in self.cfg.successors(bid) { + if !seen.contains(&succ) { + worklist.push(succ); + } + } + } + result.sort(); + result + } +} + +/// Negate an expression (for inverting branch conditions). +pub fn negate_expr(expr: Expr) -> Expr { + match expr { + Expr::Compare { op, left, right } => Expr::Compare { + op: op.negate(), + left, + right, + }, + Expr::UnaryOp { + op: UnaryOp::Not, + operand, + } => *operand, + other => Expr::UnaryOp { + op: UnaryOp::Not, + operand: Box::new(other), + }, + } +} + +/// Compute immediate dominators using a simple iterative algorithm. 
+fn compute_dominators(cfg: &ControlFlowGraph, rpo: &[BlockId]) -> HashMap { + let mut doms: HashMap = HashMap::new(); + let entry = cfg.entry; + doms.insert(entry, entry); + + let rpo_index: HashMap = rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect(); + + let mut changed = true; + while changed { + changed = false; + for &b in rpo { + if b == entry { + continue; + } + let preds = cfg.predecessors(b); + let mut new_idom: Option = None; + for p in &preds { + if !doms.contains_key(p) { + continue; + } + new_idom = Some(match new_idom { + None => *p, + Some(current) => intersect(&doms, &rpo_index, current, *p), + }); + } + if let Some(idom) = new_idom + && doms.get(&b) != Some(&idom) + { + doms.insert(b, idom); + changed = true; + } + } + } + + doms +} + +fn intersect( + doms: &HashMap, + rpo_index: &HashMap, + mut b1: BlockId, + mut b2: BlockId, +) -> BlockId { + while b1 != b2 { + let idx1 = rpo_index.get(&b1).copied().unwrap_or(usize::MAX); + let idx2 = rpo_index.get(&b2).copied().unwrap_or(usize::MAX); + if idx1 > idx2 { + b1 = *doms.get(&b1).unwrap_or(&b1); + } else { + b2 = *doms.get(&b2).unwrap_or(&b2); + } + } + b1 +} + +/// Compute post-dominators (dominators of the reverse CFG). +fn compute_post_dominators(cfg: &ControlFlowGraph, _rpo: &[BlockId]) -> HashMap { + // Find exit blocks (Return/Throw terminators) + let _exit_blocks: Vec = cfg + .blocks + .iter() + .filter(|(_, b)| matches!(b.terminator, Terminator::Return | Terminator::Throw)) + .map(|(&id, _)| id) + .collect(); + + // Simple post-dominator: for each block with a ConditionalBranch, + // find the nearest block where both branches reconverge + let mut post_doms = HashMap::new(); + + for (&block_id, block) in &cfg.blocks { + if let Terminator::ConditionalBranch { + if_true, if_false, .. 
+ } = &block.terminator + { + // Find the first block reachable from both branches + let reachable_true = reachable_set(cfg, *if_true); + let reachable_false = reachable_set(cfg, *if_false); + let common: BTreeSet = reachable_true + .intersection(&reachable_false) + .copied() + .collect(); + // The post-dominator is the first common block in order + if let Some(&first_common) = common.iter().next() { + post_doms.insert(block_id, first_common); + } + } + } + + post_doms +} + +fn reachable_set(cfg: &ControlFlowGraph, start: BlockId) -> BTreeSet { + let mut visited = BTreeSet::new(); + let mut worklist = vec![start]; + while let Some(b) = worklist.pop() { + if !visited.insert(b) { + continue; + } + if let Some(_block) = cfg.blocks.get(&b) { + for succ in cfg.successors(b) { + worklist.push(succ); + } + } + } + visited +} + +/// Find natural loop headers (blocks that are targets of back-edges). +fn find_loop_headers( + cfg: &ControlFlowGraph, + dominators: &HashMap, +) -> HashSet { + let mut headers = HashSet::new(); + for &block_id in cfg.blocks.keys() { + for succ in cfg.successors(block_id) { + // A back-edge is an edge where the target dominates the source + let mut current = block_id; + loop { + if current == succ { + headers.insert(succ); + break; + } + match dominators.get(&current) { + Some(&dom) if dom != current => current = dom, + _ => break, + } + } + } + } + headers +} diff --git a/src/decompile/type_inference.rs b/src/decompile/type_inference.rs new file mode 100644 index 0000000..a76f94a --- /dev/null +++ b/src/decompile/type_inference.rs @@ -0,0 +1,623 @@ +use crate::attribute_info::*; +use crate::constant_info::ConstantInfo; +use crate::field_info::{FieldAccessFlags, FieldInfo}; +use crate::method_info::{MethodAccessFlags, MethodInfo}; +use crate::types::{ClassAccessFlags, ClassFile}; + +use super::descriptor::{self, JvmType}; +use super::expr::Expr; +use super::java_ast::*; +use super::util; + +/// Build a JavaClass from a parsed ClassFile. 
+pub fn build_java_class(class: &ClassFile) -> JavaClass { + let const_pool = &class.const_pool; + + // Class name + let full_name = util::get_class_name(const_pool, class.this_class).unwrap_or("Unknown"); + let (package, simple_name) = split_class_name(full_name); + + // Super class + let super_class = if class.super_class != 0 { + let super_name = + util::get_class_name(const_pool, class.super_class).unwrap_or("java/lang/Object"); + if super_name != "java/lang/Object" { + Some(internal_name_to_java_type(super_name)) + } else { + None + } + } else { + None + }; + + // Interfaces + let interfaces: Vec = class + .interfaces + .iter() + .filter_map(|&idx| { + let name = util::get_class_name(const_pool, idx)?; + Some(internal_name_to_java_type(name)) + }) + .collect(); + + // Determine class kind + let kind = determine_class_kind(class); + let visibility = class_visibility(class.access_flags); + let is_final = class.access_flags.contains(ClassAccessFlags::FINAL); + let is_abstract = class.access_flags.contains(ClassAccessFlags::ABSTRACT); + + // Check for sealed (PermittedSubclasses attribute) + let (is_sealed, permitted_subclasses) = extract_permitted_subclasses(class); + + // Record components + let record_components = extract_record_components(class); + + // Source file + let source_file = extract_source_file(class); + + // Annotations + let annotations = extract_class_annotations(class); + + // Type parameters from Signature attribute + let type_parameters = extract_class_type_parameters(class); + + // Fields + let fields: Vec = class + .fields + .iter() + .map(|f| build_java_field(f, const_pool)) + .collect(); + + // Methods + let methods: Vec = class + .methods + .iter() + .map(|m| build_java_method(m, const_pool, &kind)) + .collect(); + + JavaClass { + kind, + visibility, + is_final, + is_abstract, + is_sealed, + is_static: false, + annotations, + type_parameters, + package, + name: simple_name, + super_class, + interfaces, + permitted_subclasses, + 
record_components, + fields, + methods, + inner_classes: Vec::new(), + source_file, + } +} + +fn determine_class_kind(class: &ClassFile) -> ClassKind { + let flags = class.access_flags; + if flags.contains(ClassAccessFlags::ANNOTATION) { + ClassKind::Annotation + } else if flags.contains(ClassAccessFlags::ENUM) { + ClassKind::Enum + } else if flags.contains(ClassAccessFlags::INTERFACE) { + ClassKind::Interface + } else if has_record_attribute(class) { + ClassKind::Record + } else { + ClassKind::Class + } +} + +fn has_record_attribute(class: &ClassFile) -> bool { + class + .attributes + .iter() + .any(|a| matches!(&a.info_parsed, Some(AttributeInfoVariant::Record(_)))) +} + +fn class_visibility(flags: ClassAccessFlags) -> Visibility { + if flags.contains(ClassAccessFlags::PUBLIC) { + Visibility::Public + } else { + Visibility::PackagePrivate + } +} + +fn method_visibility(flags: MethodAccessFlags) -> Visibility { + if flags.contains(MethodAccessFlags::PUBLIC) { + Visibility::Public + } else if flags.contains(MethodAccessFlags::PROTECTED) { + Visibility::Protected + } else if flags.contains(MethodAccessFlags::PRIVATE) { + Visibility::Private + } else { + Visibility::PackagePrivate + } +} + +fn field_visibility(flags: FieldAccessFlags) -> Visibility { + if flags.contains(FieldAccessFlags::PUBLIC) { + Visibility::Public + } else if flags.contains(FieldAccessFlags::PROTECTED) { + Visibility::Protected + } else if flags.contains(FieldAccessFlags::PRIVATE) { + Visibility::Private + } else { + Visibility::PackagePrivate + } +} + +fn split_class_name(internal_name: &str) -> (Option, String) { + match internal_name.rfind('/') { + Some(pos) => { + let pkg = internal_name[..pos].replace('/', "."); + let name = internal_name[pos + 1..].to_string(); + // Handle inner classes: Outer$Inner -> Inner + let simple = match name.rfind('$') { + Some(dpos) => name[dpos + 1..].to_string(), + None => name, + }; + (Some(pkg), simple) + } + None => (None, internal_name.to_string()), + } +} + 
+fn internal_name_to_java_type(name: &str) -> JavaType { + let _source = descriptor::internal_to_source_name(name); + let simple = descriptor::simple_class_name(name).to_string(); + let package = descriptor::package_name(name).map(|p| p.replace('/', ".")); + JavaType::ClassType { + package, + name: simple, + type_args: Vec::new(), + } +} + +fn jvm_type_to_java_type(ty: &JvmType) -> JavaType { + match ty { + JvmType::Int => JavaType::Primitive(PrimitiveType::Int), + JvmType::Long => JavaType::Primitive(PrimitiveType::Long), + JvmType::Float => JavaType::Primitive(PrimitiveType::Float), + JvmType::Double => JavaType::Primitive(PrimitiveType::Double), + JvmType::Byte => JavaType::Primitive(PrimitiveType::Byte), + JvmType::Char => JavaType::Primitive(PrimitiveType::Char), + JvmType::Short => JavaType::Primitive(PrimitiveType::Short), + JvmType::Boolean => JavaType::Primitive(PrimitiveType::Boolean), + JvmType::Void => JavaType::Void, + JvmType::Reference(name) => internal_name_to_java_type(name), + JvmType::Array(inner) => JavaType::ArrayType(Box::new(jvm_type_to_java_type(inner))), + JvmType::Null | JvmType::Unknown => JavaType::ClassType { + package: Some("java.lang".into()), + name: "Object".into(), + type_args: Vec::new(), + }, + } +} + +fn build_java_field(field: &FieldInfo, const_pool: &[ConstantInfo]) -> JavaField { + let name = util::get_utf8(const_pool, field.name_index) + .unwrap_or("unknown") + .to_string(); + let desc = util::get_utf8(const_pool, field.descriptor_index).unwrap_or("I"); + let jvm_type = descriptor::parse_type_descriptor(desc).unwrap_or(JvmType::Unknown); + let field_type = jvm_type_to_java_type(&jvm_type); + let flags = field.access_flags; + + // Check for ConstantValue attribute (static final initializer) + let initializer = extract_field_initializer(field, const_pool); + + // Annotations + let annotations = extract_field_annotations(field, const_pool); + + JavaField { + visibility: field_visibility(flags), + is_static: 
flags.contains(FieldAccessFlags::STATIC), + is_final: flags.contains(FieldAccessFlags::FINAL), + is_volatile: flags.contains(FieldAccessFlags::VOLATILE), + is_transient: flags.contains(FieldAccessFlags::TRANSIENT), + is_synthetic: flags.contains(FieldAccessFlags::SYNTHETIC), + is_enum_constant: flags.contains(FieldAccessFlags::ENUM), + annotations, + field_type, + name, + initializer, + } +} + +fn build_java_method( + method: &MethodInfo, + const_pool: &[ConstantInfo], + class_kind: &ClassKind, +) -> JavaMethod { + let name = util::get_utf8(const_pool, method.name_index) + .unwrap_or("unknown") + .to_string(); + let desc = util::get_utf8(const_pool, method.descriptor_index).unwrap_or("()V"); + let flags = method.access_flags; + + let (param_types, ret_type) = + descriptor::parse_method_descriptor(desc).unwrap_or((vec![], JvmType::Void)); + + let return_type = jvm_type_to_java_type(&ret_type); + + // Build parameters with names from MethodParameters or LocalVariableTable + let param_names = extract_parameter_names(method, const_pool, param_types.len()); + let parameters: Vec = param_types + .iter() + .enumerate() + .map(|(i, ty)| { + let param_name = param_names + .get(i) + .cloned() + .unwrap_or_else(|| format!("param{}", i)); + JavaParameter { + annotations: Vec::new(), + param_type: jvm_type_to_java_type(ty), + name: param_name, + is_final: false, + is_varargs: i == param_types.len() - 1 + && flags.contains(MethodAccessFlags::VARARGS), + } + }) + .collect(); + + // Throws clause + let throws = extract_throws(method, const_pool); + + // Annotations + let annotations = extract_method_annotations(method, const_pool); + + let is_default = *class_kind == ClassKind::Interface + && !flags.contains(MethodAccessFlags::ABSTRACT) + && !flags.contains(MethodAccessFlags::STATIC); + + JavaMethod { + visibility: method_visibility(flags), + is_static: flags.contains(MethodAccessFlags::STATIC), + is_final: flags.contains(MethodAccessFlags::FINAL), + is_abstract: 
flags.contains(MethodAccessFlags::ABSTRACT), + is_synchronized: flags.contains(MethodAccessFlags::SYNCHRONIZED), + is_native: flags.contains(MethodAccessFlags::NATIVE), + is_default, + is_synthetic: flags.contains(MethodAccessFlags::SYNTHETIC), + is_bridge: flags.contains(MethodAccessFlags::BRIDGE), + annotations, + type_parameters: Vec::new(), + return_type, + name, + parameters, + throws, + body: None, // Populated later by the decompiler + error: None, + } +} + +fn extract_parameter_names( + method: &MethodInfo, + const_pool: &[ConstantInfo], + param_count: usize, +) -> Vec { + // Try MethodParameters attribute first + for attr in &method.attributes { + if let Some(AttributeInfoVariant::MethodParameters(mp)) = &attr.info_parsed { + return mp + .parameters + .iter() + .map(|p| { + if p.name_index != 0 { + util::get_utf8(const_pool, p.name_index) + .unwrap_or("param") + .to_string() + } else { + "param".to_string() + } + }) + .collect(); + } + } + + // Try LocalVariableTable from Code attribute + if let Some(code) = method.code() { + for attr in &code.attributes { + if let Some(AttributeInfoVariant::LocalVariableTable(lvt)) = &attr.info_parsed { + let is_static = method.access_flags.contains(MethodAccessFlags::STATIC); + let start_idx: u16 = if is_static { 0 } else { 1 }; + let mut names = Vec::new(); + for i in 0..param_count { + let slot = start_idx + i as u16; + let name = lvt + .items + .iter() + .find(|item| item.index == slot && item.start_pc == 0) + .and_then(|item| util::get_utf8(const_pool, item.name_index)) + .unwrap_or("param") + .to_string(); + names.push(name); + } + return names; + } + } + } + + // Fallback + (0..param_count).map(|i| format!("param{}", i)).collect() +} + +fn extract_throws(method: &MethodInfo, const_pool: &[ConstantInfo]) -> Vec { + for attr in &method.attributes { + if let Some(AttributeInfoVariant::Exceptions(exc)) = &attr.info_parsed { + return exc + .exception_table + .iter() + .filter_map(|&idx| { + let name = 
util::get_class_name(const_pool, idx)?; + Some(internal_name_to_java_type(name)) + }) + .collect(); + } + } + Vec::new() +} + +fn extract_field_initializer(field: &FieldInfo, const_pool: &[ConstantInfo]) -> Option { + for attr in &field.attributes { + if let Some(AttributeInfoVariant::ConstantValue(cv)) = &attr.info_parsed { + let idx = cv.constant_value_index; + return match const_pool.get((idx as usize).checked_sub(1)?) { + Some(ConstantInfo::Integer(c)) => Some(Expr::IntLiteral(c.value)), + Some(ConstantInfo::Long(c)) => Some(Expr::LongLiteral(c.value)), + Some(ConstantInfo::Float(c)) => Some(Expr::FloatLiteral(c.value)), + Some(ConstantInfo::Double(c)) => Some(Expr::DoubleLiteral(c.value)), + Some(ConstantInfo::String(s)) => { + let string = util::get_utf8(const_pool, s.string_index)?.to_string(); + Some(Expr::StringLiteral(string)) + } + _ => None, + }; + } + } + None +} + +fn extract_permitted_subclasses(class: &ClassFile) -> (bool, Vec) { + for attr in &class.attributes { + if let Some(AttributeInfoVariant::PermittedSubclasses(ps)) = &attr.info_parsed { + let types: Vec = ps + .classes + .iter() + .filter_map(|&idx| { + let name = util::get_class_name(&class.const_pool, idx)?; + Some(internal_name_to_java_type(name)) + }) + .collect(); + return (true, types); + } + } + (false, Vec::new()) +} + +fn extract_record_components(class: &ClassFile) -> Vec { + for attr in &class.attributes { + if let Some(AttributeInfoVariant::Record(rec)) = &attr.info_parsed { + return rec + .components + .iter() + .filter_map(|c| { + let name = util::get_utf8(&class.const_pool, c.name_index)?.to_string(); + let desc = util::get_utf8(&class.const_pool, c.descriptor_index)?; + let jvm_type = descriptor::parse_type_descriptor(desc)?; + Some(RecordComponent { + annotations: Vec::new(), + component_type: jvm_type_to_java_type(&jvm_type), + name, + }) + }) + .collect(); + } + } + Vec::new() +} + +fn extract_source_file(class: &ClassFile) -> Option { + for attr in &class.attributes { + 
if let Some(AttributeInfoVariant::SourceFile(sf)) = &attr.info_parsed { + return util::get_utf8(&class.const_pool, sf.sourcefile_index).map(|s| s.to_string()); + } + } + None +} + +fn extract_class_annotations(class: &ClassFile) -> Vec { + let mut annotations = Vec::new(); + for attr in &class.attributes { + match &attr.info_parsed { + Some(AttributeInfoVariant::RuntimeVisibleAnnotations(ra)) => { + for ann in &ra.annotations { + if let Some(a) = convert_annotation(ann, &class.const_pool) { + annotations.push(a); + } + } + } + Some(AttributeInfoVariant::Deprecated(_)) => { + annotations.push(JavaAnnotation { + type_name: "Deprecated".into(), + arguments: Vec::new(), + }); + } + _ => {} + } + } + annotations +} + +fn extract_field_annotations( + field: &FieldInfo, + const_pool: &[ConstantInfo], +) -> Vec { + let mut annotations = Vec::new(); + for attr in &field.attributes { + if let Some(AttributeInfoVariant::RuntimeVisibleAnnotations(ra)) = &attr.info_parsed { + for ann in &ra.annotations { + if let Some(a) = convert_annotation(ann, const_pool) { + annotations.push(a); + } + } + } + } + annotations +} + +fn extract_method_annotations( + method: &MethodInfo, + const_pool: &[ConstantInfo], +) -> Vec { + let mut annotations = Vec::new(); + for attr in &method.attributes { + match &attr.info_parsed { + Some(AttributeInfoVariant::RuntimeVisibleAnnotations(ra)) => { + for ann in &ra.annotations { + if let Some(a) = convert_annotation(ann, const_pool) { + annotations.push(a); + } + } + } + Some(AttributeInfoVariant::Deprecated(_)) => { + annotations.push(JavaAnnotation { + type_name: "Deprecated".into(), + arguments: Vec::new(), + }); + } + _ => {} + } + } + annotations +} + +fn convert_annotation( + ann: &RuntimeAnnotation, + const_pool: &[ConstantInfo], +) -> Option { + let type_desc = util::get_utf8(const_pool, ann.type_index)?; + // Type descriptor is like "Ljava/lang/Override;" -> "Override" + let type_name = if type_desc.starts_with('L') && type_desc.ends_with(';') 
{ + let internal = &type_desc[1..type_desc.len() - 1]; + descriptor::simple_class_name(internal).to_string() + } else { + type_desc.to_string() + }; + + let arguments: Vec = ann + .element_value_pairs + .iter() + .filter_map(|evp| { + let name = util::get_utf8(const_pool, evp.element_name_index)?.to_string(); + let value = convert_element_value(&evp.value, const_pool)?; + Some(AnnotationArgument::Named { name, value }) + }) + .collect(); + + Some(JavaAnnotation { + type_name, + arguments, + }) +} + +fn convert_element_value( + ev: &ElementValue, + const_pool: &[ConstantInfo], +) -> Option { + match ev { + ElementValue::ConstValueIndex(cv) => { + let idx = cv.value; + match cv.tag { + 'B' | 'C' | 'I' | 'S' | 'Z' => { + if let Some(ConstantInfo::Integer(c)) = + const_pool.get((idx as usize).checked_sub(1)?) + { + if cv.tag == 'Z' { + Some(AnnotationValue::BooleanLiteral(c.value != 0)) + } else if cv.tag == 'C' { + Some(AnnotationValue::CharLiteral(c.value as u8 as char)) + } else { + Some(AnnotationValue::IntLiteral(c.value)) + } + } else { + None + } + } + 'J' => { + if let Some(ConstantInfo::Long(c)) = + const_pool.get((idx as usize).checked_sub(1)?) + { + Some(AnnotationValue::LongLiteral(c.value)) + } else { + None + } + } + 'F' => { + if let Some(ConstantInfo::Float(c)) = + const_pool.get((idx as usize).checked_sub(1)?) + { + Some(AnnotationValue::FloatLiteral(c.value)) + } else { + None + } + } + 'D' => { + if let Some(ConstantInfo::Double(c)) = + const_pool.get((idx as usize).checked_sub(1)?) 
+ { + Some(AnnotationValue::DoubleLiteral(c.value)) + } else { + None + } + } + 's' => { + let s = util::get_utf8(const_pool, idx)?.to_string(); + Some(AnnotationValue::StringLiteral(s)) + } + _ => None, + } + } + ElementValue::EnumConst(ec) => { + let type_name = util::get_utf8(const_pool, ec.type_name_index)?; + let const_name = util::get_utf8(const_pool, ec.const_name_index)?.to_string(); + let type_simple = if type_name.starts_with('L') && type_name.ends_with(';') { + descriptor::simple_class_name(&type_name[1..type_name.len() - 1]).to_string() + } else { + type_name.to_string() + }; + Some(AnnotationValue::EnumConstant { + type_name: type_simple, + const_name, + }) + } + ElementValue::ClassInfoIndex(idx) => { + let desc = util::get_utf8(const_pool, *idx)?.to_string(); + Some(AnnotationValue::ClassLiteral(desc)) + } + ElementValue::AnnotationValue(ann) => { + let a = convert_annotation(ann, const_pool)?; + Some(AnnotationValue::AnnotationLiteral(a)) + } + ElementValue::ElementArray(arr) => { + let values: Vec = arr + .values + .iter() + .filter_map(|v| convert_element_value(v, const_pool)) + .collect(); + Some(AnnotationValue::ArrayLiteral(values)) + } + } +} + +fn extract_class_type_parameters(_class: &ClassFile) -> Vec { + // TODO: Parse Signature attribute for class-level generic type parameters + // The signature format is: Ljava/lang/Object; + Vec::new() +} diff --git a/src/decompile/util.rs b/src/decompile/util.rs new file mode 100644 index 0000000..ea88ba2 --- /dev/null +++ b/src/decompile/util.rs @@ -0,0 +1,253 @@ +use crate::code_attribute::Instruction; +use crate::constant_info::ConstantInfo; + +/// Returns the byte size of an instruction in the code array. +/// `address` is the bytecode offset of this instruction (needed for switch alignment). 
+pub fn instruction_byte_size(instr: &Instruction, address: u32) -> u32 { + match instr { + Instruction::Nop => 1, + Instruction::Aconstnull => 1, + Instruction::Iconstm1 + | Instruction::Iconst0 + | Instruction::Iconst1 + | Instruction::Iconst2 + | Instruction::Iconst3 + | Instruction::Iconst4 + | Instruction::Iconst5 => 1, + Instruction::Lconst0 | Instruction::Lconst1 => 1, + Instruction::Fconst0 | Instruction::Fconst1 | Instruction::Fconst2 => 1, + Instruction::Dconst0 | Instruction::Dconst1 => 1, + Instruction::Bipush(_) => 2, + Instruction::Sipush(_) => 3, + Instruction::Ldc(_) => 2, + Instruction::LdcW(_) => 3, + Instruction::Ldc2W(_) => 3, + Instruction::Iload(_) + | Instruction::Lload(_) + | Instruction::Fload(_) + | Instruction::Dload(_) + | Instruction::Aload(_) => 2, + Instruction::Iload0 | Instruction::Iload1 | Instruction::Iload2 | Instruction::Iload3 => 1, + Instruction::Lload0 | Instruction::Lload1 | Instruction::Lload2 | Instruction::Lload3 => 1, + Instruction::Fload0 | Instruction::Fload1 | Instruction::Fload2 | Instruction::Fload3 => 1, + Instruction::Dload0 | Instruction::Dload1 | Instruction::Dload2 | Instruction::Dload3 => 1, + Instruction::Aload0 | Instruction::Aload1 | Instruction::Aload2 | Instruction::Aload3 => 1, + Instruction::Iaload + | Instruction::Laload + | Instruction::Faload + | Instruction::Daload + | Instruction::Aaload + | Instruction::Baload + | Instruction::Caload + | Instruction::Saload => 1, + Instruction::Istore(_) + | Instruction::Lstore(_) + | Instruction::Fstore(_) + | Instruction::Dstore(_) + | Instruction::Astore(_) => 2, + Instruction::Istore0 + | Instruction::Istore1 + | Instruction::Istore2 + | Instruction::Istore3 => 1, + Instruction::Lstore0 + | Instruction::Lstore1 + | Instruction::Lstore2 + | Instruction::Lstore3 => 1, + Instruction::Fstore0 + | Instruction::Fstore1 + | Instruction::Fstore2 + | Instruction::Fstore3 => 1, + Instruction::Dstore0 + | Instruction::Dstore1 + | Instruction::Dstore2 + | 
Instruction::Dstore3 => 1, + Instruction::Astore0 + | Instruction::Astore1 + | Instruction::Astore2 + | Instruction::Astore3 => 1, + Instruction::Iastore + | Instruction::Lastore + | Instruction::Fastore + | Instruction::Dastore + | Instruction::Aastore + | Instruction::Bastore + | Instruction::Castore + | Instruction::Sastore => 1, + Instruction::Pop => 1, + Instruction::Pop2 => 1, + Instruction::Dup => 1, + Instruction::Dupx1 => 1, + Instruction::Dupx2 => 1, + Instruction::Dup2 => 1, + Instruction::Dup2x1 => 1, + Instruction::Dup2x2 => 1, + Instruction::Swap => 1, + Instruction::Iadd | Instruction::Ladd | Instruction::Fadd | Instruction::Dadd => 1, + Instruction::Isub | Instruction::Lsub | Instruction::Fsub | Instruction::Dsub => 1, + Instruction::Imul | Instruction::Lmul | Instruction::Fmul | Instruction::Dmul => 1, + Instruction::Idiv | Instruction::Ldiv | Instruction::Fdiv | Instruction::Ddiv => 1, + Instruction::Irem | Instruction::Lrem | Instruction::Frem | Instruction::Drem => 1, + Instruction::Ineg | Instruction::Lneg | Instruction::Fneg | Instruction::Dneg => 1, + Instruction::Ishl | Instruction::Lshl => 1, + Instruction::Ishr | Instruction::Lshr => 1, + Instruction::Iushr | Instruction::Lushr => 1, + Instruction::Iand | Instruction::Land => 1, + Instruction::Ior | Instruction::Lor => 1, + Instruction::Ixor | Instruction::Lxor => 1, + Instruction::Iinc { .. 
} => 3, + Instruction::I2l | Instruction::I2f | Instruction::I2d => 1, + Instruction::L2i | Instruction::L2f | Instruction::L2d => 1, + Instruction::F2i | Instruction::F2l | Instruction::F2d => 1, + Instruction::D2i | Instruction::D2l | Instruction::D2f => 1, + Instruction::I2b | Instruction::I2c | Instruction::I2s => 1, + Instruction::Lcmp => 1, + Instruction::Fcmpl | Instruction::Fcmpg => 1, + Instruction::Dcmpl | Instruction::Dcmpg => 1, + Instruction::Ifeq(_) + | Instruction::Ifne(_) + | Instruction::Iflt(_) + | Instruction::Ifge(_) + | Instruction::Ifgt(_) + | Instruction::Ifle(_) => 3, + Instruction::IfIcmpeq(_) + | Instruction::IfIcmpne(_) + | Instruction::IfIcmplt(_) + | Instruction::IfIcmpge(_) + | Instruction::IfIcmpgt(_) + | Instruction::IfIcmple(_) => 3, + Instruction::IfAcmpeq(_) | Instruction::IfAcmpne(_) => 3, + Instruction::Goto(_) => 3, + Instruction::Jsr(_) => 3, + Instruction::Ret(_) => 2, + Instruction::Tableswitch { low, high, .. } => { + let padding = (4 - (address + 1) % 4) % 4; + // 1 (opcode) + padding + 4 (default) + 4 (low) + 4 (high) + 4*(high-low+1) + 1 + padding + 4 + 4 + 4 + 4 * (high - low + 1) as u32 + } + Instruction::Lookupswitch { npairs, .. } => { + let padding = (4 - (address + 1) % 4) % 4; + // 1 (opcode) + padding + 4 (default) + 4 (npairs) + 8*npairs + 1 + padding + 4 + 4 + 8 * npairs + } + Instruction::Getstatic(_) + | Instruction::Putstatic(_) + | Instruction::Getfield(_) + | Instruction::Putfield(_) => 3, + Instruction::Invokevirtual(_) + | Instruction::Invokespecial(_) + | Instruction::Invokestatic(_) => 3, + Instruction::Invokeinterface { .. } => 5, + Instruction::Invokedynamic { .. } => 5, + Instruction::New(_) => 3, + Instruction::Newarray(_) => 2, + Instruction::Anewarray(_) => 3, + Instruction::Arraylength => 1, + Instruction::Athrow => 1, + Instruction::Checkcast(_) => 3, + Instruction::Instanceof(_) => 3, + Instruction::Monitorenter | Instruction::Monitorexit => 1, + Instruction::Multianewarray { .. 
} => 4, + Instruction::Ifnull(_) | Instruction::Ifnonnull(_) => 3, + Instruction::GotoW(_) => 5, + Instruction::JsrW(_) => 5, + Instruction::Areturn + | Instruction::Ireturn + | Instruction::Lreturn + | Instruction::Freturn + | Instruction::Dreturn + | Instruction::Return => 1, + // wide instructions: 2 bytes magic + 2 bytes index + Instruction::IloadWide(_) + | Instruction::LloadWide(_) + | Instruction::FloadWide(_) + | Instruction::DloadWide(_) + | Instruction::AloadWide(_) => 4, + Instruction::IstoreWide(_) + | Instruction::LstoreWide(_) + | Instruction::FstoreWide(_) + | Instruction::DstoreWide(_) + | Instruction::AstoreWide(_) => 4, + Instruction::RetWide(_) => 4, + Instruction::IincWide { .. } => 6, + } +} + +/// Compute byte addresses for each instruction in a code array. +/// Returns Vec<(address, &Instruction)>. +pub fn compute_addresses(code: &[Instruction]) -> Vec<(u32, &Instruction)> { + let mut result = Vec::with_capacity(code.len()); + let mut address = 0u32; + for instr in code { + result.push((address, instr)); + address += instruction_byte_size(instr, address); + } + result +} + +/// Look up a UTF-8 constant pool entry by 1-based index. +pub fn get_utf8(const_pool: &[ConstantInfo], index: u16) -> Option<&str> { + match const_pool.get((index as usize).checked_sub(1)?)? { + ConstantInfo::Utf8(u) => Some(&u.utf8_string), + _ => None, + } +} + +/// Resolve a Class constant to its name string. +pub fn get_class_name(const_pool: &[ConstantInfo], class_index: u16) -> Option<&str> { + match const_pool.get((class_index as usize).checked_sub(1)?)? { + ConstantInfo::Class(c) => get_utf8(const_pool, c.name_index), + _ => None, + } +} + +/// Resolve a NameAndType constant to (name, descriptor). +pub fn get_name_and_type(const_pool: &[ConstantInfo], nat_index: u16) -> Option<(&str, &str)> { + match const_pool.get((nat_index as usize).checked_sub(1)?)? 
{ + ConstantInfo::NameAndType(nat) => { + let name = get_utf8(const_pool, nat.name_index)?; + let desc = get_utf8(const_pool, nat.descriptor_index)?; + Some((name, desc)) + } + _ => None, + } +} + +/// Resolve a FieldRef, MethodRef, or InterfaceMethodRef to (class_name, method_name, descriptor). +pub fn resolve_ref(const_pool: &[ConstantInfo], index: u16) -> Option<(&str, &str, &str)> { + let entry = const_pool.get((index as usize).checked_sub(1)?)?; + let (class_index, nat_index) = match entry { + ConstantInfo::FieldRef(r) => (r.class_index, r.name_and_type_index), + ConstantInfo::MethodRef(r) => (r.class_index, r.name_and_type_index), + ConstantInfo::InterfaceMethodRef(r) => (r.class_index, r.name_and_type_index), + _ => return None, + }; + let class_name = get_class_name(const_pool, class_index)?; + let (name, desc) = get_name_and_type(const_pool, nat_index)?; + Some((class_name, name, desc)) +} + +/// Get a constant pool entry's value as a string for display. +pub fn format_constant(const_pool: &[ConstantInfo], index: u16) -> String { + match const_pool.get((index as usize).wrapping_sub(1)) { + Some(ConstantInfo::Integer(c)) => format!("{}", c.value), + Some(ConstantInfo::Float(c)) => format!("{}f", c.value), + Some(ConstantInfo::Long(c)) => format!("{}L", c.value), + Some(ConstantInfo::Double(c)) => format!("{}d", c.value), + Some(ConstantInfo::String(c)) => { + if let Some(s) = get_utf8(const_pool, c.string_index) { + format!("\"{}\"", s) + } else { + format!("", c.string_index) + } + } + Some(ConstantInfo::Class(c)) => { + if let Some(name) = get_utf8(const_pool, c.name_index) { + name.to_string() + } else { + format!("", c.name_index) + } + } + Some(ConstantInfo::Utf8(c)) => c.utf8_string.clone(), + _ => format!("", index), + } +} diff --git a/src/field_info/mod.rs b/src/field_info/mod.rs index 62a9f11..6108ca4 100644 --- a/src/field_info/mod.rs +++ b/src/field_info/mod.rs @@ -1,5 +1,3 @@ -mod parser; mod types; -pub use self::parser::field_parser; pub use 
self::types::*; diff --git a/src/field_info/parser.rs b/src/field_info/parser.rs deleted file mode 100644 index 1a7a9dd..0000000 --- a/src/field_info/parser.rs +++ /dev/null @@ -1,23 +0,0 @@ -use nom::{IResult, multi::count, number::complete::be_u16}; - -use crate::attribute_info::attribute_parser; - -use crate::field_info::{FieldAccessFlags, FieldInfo}; - -pub fn field_parser(input: &[u8]) -> IResult<&[u8], FieldInfo> { - let (input, access_flags) = be_u16(input)?; - let (input, name_index) = be_u16(input)?; - let (input, descriptor_index) = be_u16(input)?; - let (input, attributes_count) = be_u16(input)?; - let (input, attributes) = count(attribute_parser, attributes_count as usize)(input)?; - Ok(( - input, - FieldInfo { - access_flags: FieldAccessFlags::from_bits_truncate(access_flags), - name_index, - descriptor_index, - attributes_count, - attributes, - }, - )) -} diff --git a/src/field_info/types.rs b/src/field_info/types.rs index 91cf144..2fd6b39 100644 --- a/src/field_info/types.rs +++ b/src/field_info/types.rs @@ -1,4 +1,4 @@ -use crate::attribute_info::AttributeInfo; +use crate::{InterpretInner, attribute_info::AttributeInfo}; use binrw::binrw; #[derive(Clone, Debug)] @@ -9,10 +9,18 @@ pub struct FieldInfo { pub name_index: u16, pub descriptor_index: u16, pub attributes_count: u16, - #[br(args { count: attributes_count.into() })] + #[br(count = attributes_count)] pub attributes: Vec, } +impl InterpretInner for FieldInfo { + fn interpret_inner(&mut self, const_pool: &Vec) { + for attr in &mut self.attributes { + attr.interpret_inner(const_pool); + } + } +} + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] #[binrw] pub struct FieldAccessFlags(u16); @@ -31,13 +39,3 @@ bitflags! { const ENUM = 0x4000; // Declared as an element of an enum. 
} } - -#[cfg(test)] -#[allow(dead_code)] -trait TraitTester: - Copy + Clone + PartialEq + Eq + PartialOrd + Ord + ::std::hash::Hash + ::std::fmt::Debug -{ -} - -#[cfg(test)] -impl TraitTester for FieldAccessFlags {} diff --git a/src/jar_patch/mod.rs b/src/jar_patch/mod.rs new file mode 100644 index 0000000..e289f0e --- /dev/null +++ b/src/jar_patch/mod.rs @@ -0,0 +1,242 @@ +use std::fmt; + +use crate::compile::{CompileError, CompileOptions, compile_method_body}; +use crate::jar_utils::{JarError, JarFile}; + +// --------------------------------------------------------------------------- +// Error type +// --------------------------------------------------------------------------- + +#[derive(Debug)] +pub enum JarPatchError { + Jar(JarError), + Compile(CompileError), +} + +impl fmt::Display for JarPatchError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + JarPatchError::Jar(e) => write!(f, "jar error: {e}"), + JarPatchError::Compile(e) => write!(f, "compile error: {e}"), + } + } +} + +impl std::error::Error for JarPatchError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + JarPatchError::Jar(e) => Some(e), + JarPatchError::Compile(e) => Some(e), + } + } +} + +impl From<JarError> for JarPatchError { + fn from(e: JarError) -> Self { + JarPatchError::Jar(e) + } +} + +impl From<CompileError> for JarPatchError { + fn from(e: CompileError) -> Self { + JarPatchError::Compile(e) + } +} + +pub type JarPatchResult<T> = Result<T, JarPatchError>; + +// --------------------------------------------------------------------------- +// Functions +// --------------------------------------------------------------------------- + +/// Compile Java source and replace a single method's body in a class inside +/// a JAR. +/// +/// Parses the class from the JAR, patches the specified method with the +/// compiled source, and writes the modified class back.
+pub fn patch_jar_method( + jar: &mut JarFile, + class_path: &str, + method_name: &str, + source: &str, + options: &CompileOptions, +) -> JarPatchResult<()> { + let mut class_file = jar.parse_class(class_path)?; + compile_method_body(source, &mut class_file, method_name, None, options)?; + jar.set_class(class_path, &class_file)?; + Ok(()) +} + +/// Compile and patch multiple methods in a single class inside a JAR. +/// +/// The class is parsed once, all methods are patched in order, and the +/// modified class is written back once. If any method fails, the error is +/// returned immediately and the JAR entry is not updated. +pub fn patch_jar_class( + jar: &mut JarFile, + class_path: &str, + patches: &[(&str, &str)], + options: &CompileOptions, +) -> JarPatchResult<()> { + let mut class_file = jar.parse_class(class_path)?; + for &(method_name, source) in patches { + compile_method_body(source, &mut class_file, method_name, None, options)?; + } + jar.set_class(class_path, &class_file)?; + Ok(()) +} + +/// Compile and patch methods across multiple classes in a JAR. +/// +/// Each entry is `(class_path, &[(method_name, source)])`. Classes are +/// processed one at a time. If any class fails, the error is returned +/// immediately. +pub fn patch_jar_classes( + jar: &mut JarFile, + patches: &[(&str, &[(&str, &str)])], + options: &CompileOptions, +) -> JarPatchResult<()> { + for &(class_path, method_patches) in patches { + patch_jar_class(jar, class_path, method_patches, options)?; + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// Macros +// --------------------------------------------------------------------------- + +/// Compile and patch a single method in a class inside a JAR. +/// +/// Generates a valid StackMapTable by default so the patched class passes +/// full JVM bytecode verification. 
+/// +/// # Forms +/// +/// ```ignore +/// // With StackMapTable (default): +/// patch_jar_method!(jar, "com/example/Main.class", "main", r#"{ ... }"#)?; +/// +/// // Without StackMapTable (requires -noverify): +/// patch_jar_method!(jar, "com/example/Main.class", "main", r#"{ ... }"#, no_verify)?; +/// ``` +#[macro_export] +macro_rules! patch_jar_method { + ($jar:expr, $class_path:expr, $method:expr, $source:expr) => { + $crate::jar_patch::patch_jar_method( + &mut $jar, + $class_path, + $method, + $source, + &$crate::compile::CompileOptions { + generate_stack_map_table: true, + ..$crate::compile::CompileOptions::default() + }, + ) + }; + ($jar:expr, $class_path:expr, $method:expr, $source:expr, no_verify) => { + $crate::jar_patch::patch_jar_method( + &mut $jar, + $class_path, + $method, + $source, + &$crate::compile::CompileOptions::default(), + ) + }; +} + +/// Compile and patch multiple methods in a single class inside a JAR. +/// +/// The class is parsed once, all methods are patched, and the class is +/// written back once. Generates a valid StackMapTable by default. +/// +/// ```ignore +/// patch_jar_class!(jar, "com/example/Main.class", { +/// "main" => r#"{ System.out.println("hello"); }"#, +/// "helper" => r#"{ return 42; }"#, +/// })?; +/// +/// // Without StackMapTable: +/// patch_jar_class!(jar, "com/example/Main.class", no_verify, { +/// "main" => r#"{ ... }"#, +/// })?; +/// ``` +#[macro_export] +macro_rules! patch_jar_class { + ($jar:expr, $class_path:expr, { $($method:expr => $source:expr),+ $(,)? }) => { + $crate::jar_patch::patch_jar_class( + &mut $jar, + $class_path, + &[ $( ($method, $source) ),+ ], + &$crate::compile::CompileOptions { + generate_stack_map_table: true, + ..$crate::compile::CompileOptions::default() + }, + ) + }; + ($jar:expr, $class_path:expr, no_verify, { $($method:expr => $source:expr),+ $(,)? 
}) => { + $crate::jar_patch::patch_jar_class( + &mut $jar, + $class_path, + &[ $( ($method, $source) ),+ ], + &$crate::compile::CompileOptions::default(), + ) + }; +} + +/// Compile and patch methods across multiple classes in a JAR. +/// +/// Each class is parsed once, its methods are patched, and the class is +/// written back before moving to the next. Generates a valid StackMapTable +/// by default. +/// +/// ```ignore +/// patch_jar!(jar, { +/// "com/example/Main.class" => { +/// "main" => r#"{ System.out.println("hello"); }"#, +/// "helper" => r#"{ return 42; }"#, +/// }, +/// "com/example/Util.class" => { +/// "compute" => r#"{ return 0; }"#, +/// }, +/// })?; +/// +/// // Without StackMapTable: +/// patch_jar!(jar, no_verify, { +/// "com/example/Main.class" => { +/// "main" => r#"{ ... }"#, +/// }, +/// })?; +/// ``` +#[macro_export] +macro_rules! patch_jar { + ($jar:expr, { $($class_path:expr => { $($method:expr => $source:expr),+ $(,)? }),+ $(,)? }) => {{ + (|| -> Result<(), $crate::jar_patch::JarPatchError> { + $( + $crate::jar_patch::patch_jar_class( + &mut $jar, + $class_path, + &[ $( ($method, $source) ),+ ], + &$crate::compile::CompileOptions { + generate_stack_map_table: true, + ..$crate::compile::CompileOptions::default() + }, + )?; + )+ + Ok(()) + })() + }}; + ($jar:expr, no_verify, { $($class_path:expr => { $($method:expr => $source:expr),+ $(,)? }),+ $(,)? }) => {{ + (|| -> Result<(), $crate::jar_patch::JarPatchError> { + $( + $crate::jar_patch::patch_jar_class( + &mut $jar, + $class_path, + &[ $( ($method, $source) ),+ ], + &$crate::compile::CompileOptions::default(), + )?; + )+ + Ok(()) + })() + }}; +} diff --git a/src/jar_utils/manifest.rs b/src/jar_utils/manifest.rs new file mode 100644 index 0000000..f3c7404 --- /dev/null +++ b/src/jar_utils/manifest.rs @@ -0,0 +1,249 @@ +use super::types::{JarError, JarResult}; + +/// Ordered collection of key-value pairs with case-insensitive key lookup. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ManifestAttributes { + entries: Vec<(String, String)>, +} + +impl ManifestAttributes { + pub fn new() -> Self { + Self { + entries: Vec::new(), + } + } + + /// Case-insensitive key lookup. + pub fn get(&self, key: &str) -> Option<&str> { + self.entries + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case(key)) + .map(|(_, v)| v.as_str()) + } + + /// Replace if a matching key exists (case-insensitive), otherwise append. + pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) { + let key = key.into(); + let value = value.into(); + if let Some(entry) = self + .entries + .iter_mut() + .find(|(k, _)| k.eq_ignore_ascii_case(&key)) + { + entry.0 = key; + entry.1 = value; + } else { + self.entries.push((key, value)); + } + } + + /// Remove the first entry matching the key (case-insensitive). Returns the value if found. + pub fn remove(&mut self, key: &str) -> Option<String> { + let pos = self + .entries + .iter() + .position(|(k, _)| k.eq_ignore_ascii_case(key))?; + Some(self.entries.remove(pos).1) + } + + pub fn contains_key(&self, key: &str) -> bool { + self.entries + .iter() + .any(|(k, _)| k.eq_ignore_ascii_case(key)) + } + + pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> { + self.entries.iter().map(|(k, v)| (k.as_str(), v.as_str())) + } + + pub fn len(&self) -> usize { + self.entries.len() + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } +} + +impl Default for ManifestAttributes { + fn default() -> Self { + Self::new() + } +} + +/// Structured representation of a JAR manifest (`META-INF/MANIFEST.MF`). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct JarManifest { + pub main_attributes: ManifestAttributes, + pub entries: std::collections::BTreeMap<String, ManifestAttributes>, +} + +impl JarManifest { + /// Parse a manifest from raw bytes (assumed UTF-8).
+ pub fn parse(data: &[u8]) -> JarResult<Self> { + let text = std::str::from_utf8(data) + .map_err(|e| JarError::ManifestParse(format!("invalid UTF-8: {e}")))?; + + // Split into lines, handling both \r\n and \n + let raw_lines: Vec<&str> = text.split('\n').collect(); + + // Join continuation lines (lines starting with a single space) + let mut logical_lines: Vec<String> = Vec::new(); + for raw in &raw_lines { + let line = raw.strip_suffix('\r').unwrap_or(raw); + if line.starts_with(' ') && !logical_lines.is_empty() { + // Continuation line — append without the leading space + let last = logical_lines.last_mut().unwrap(); + last.push_str(&line[1..]); + } else { + logical_lines.push(line.to_string()); + } + } + + let mut main_attributes = ManifestAttributes::new(); + let mut entries = std::collections::BTreeMap::new(); + let mut current_section: Option<(String, ManifestAttributes)> = None; + let mut in_main = true; + + for line in &logical_lines { + if line.is_empty() { + // Blank line separates sections + if let Some((name, attrs)) = current_section.take() { + entries.insert(name, attrs); + } + in_main = false; + continue; + } + + // Split on first ": " + let Some(colon_pos) = line.find(": ") else { + // Lines without ": " are ignored (e.g.
trailing whitespace) + continue; + }; + let key = &line[..colon_pos]; + let value = &line[colon_pos + 2..]; + + if in_main { + main_attributes.set(key, value); + } else if key.eq_ignore_ascii_case("Name") && current_section.is_none() { + current_section = Some((value.to_string(), ManifestAttributes::new())); + } else if let Some((_, ref mut attrs)) = current_section { + attrs.set(key, value); + } else { + // Name: starts a new section + if key.eq_ignore_ascii_case("Name") { + current_section = Some((value.to_string(), ManifestAttributes::new())); + } + } + } + + // Flush last section + if let Some((name, attrs)) = current_section { + entries.insert(name, attrs); + } + + Ok(JarManifest { + main_attributes, + entries, + }) + } + + /// Serialize to bytes with \r\n line endings and 72-byte line wrapping. + pub fn to_bytes(&self) -> Vec<u8> { + let mut out = String::new(); + + // Main section + write_section(&mut out, &self.main_attributes); + + // Per-entry sections + for (name, attrs) in &self.entries { + out.push_str("\r\n"); + write_wrapped_line(&mut out, "Name", name); + write_section(&mut out, attrs); + } + + out.into_bytes() + } + + /// Shorthand for `main_attributes.get(key)`. + pub fn main_attr(&self, key: &str) -> Option<&str> { + self.main_attributes.get(key) + } + + /// Shorthand for `main_attributes.set(key, value)`. + pub fn set_main_attr(&mut self, key: impl Into<String>, value: impl Into<String>) { + self.main_attributes.set(key, value); + } + + /// Get a per-entry section by name. + pub fn entry_section(&self, name: &str) -> Option<&ManifestAttributes> { + self.entries.get(name) + } + + /// Get or create a per-entry section by name (mutable). + pub fn entry_section_mut(&mut self, name: impl Into<String>) -> &mut ManifestAttributes { + self.entries.entry(name.into()).or_default() + } + + /// Create a default manifest with `Manifest-Version: 1.0`.
+ pub fn default_manifest() -> Self { + let mut main_attributes = ManifestAttributes::new(); + main_attributes.set("Manifest-Version", "1.0"); + JarManifest { + main_attributes, + entries: std::collections::BTreeMap::new(), + } + } +} + +/// Write all attributes in a section, each line wrapped at 72 bytes. +fn write_section(out: &mut String, attrs: &ManifestAttributes) { + for (key, value) in attrs.iter() { + write_wrapped_line(out, key, value); + } +} + +/// Write a single `Key: Value\r\n` with continuation wrapping at 72 bytes. +fn write_wrapped_line(out: &mut String, key: &str, value: &str) { + let full = format!("{}: {}", key, value); + let bytes = full.as_bytes(); + + if bytes.len() <= 72 { + out.push_str(&full); + out.push_str("\r\n"); + return; + } + + // First line: up to 72 bytes + // Find a safe UTF-8 boundary at or before byte 72 + let first_end = safe_split_pos(&full, 72); + out.push_str(&full[..first_end]); + out.push_str("\r\n"); + + let mut pos = first_end; + while pos < bytes.len() { + // Continuation lines: " " + up to 71 bytes of content = 72 bytes total + let chunk_end = safe_split_pos(&full[pos..], 71); + out.push(' '); + out.push_str(&full[pos..pos + chunk_end]); + out.push_str("\r\n"); + pos += chunk_end; + } +} + +/// Find the largest byte position <= max_bytes that is a valid UTF-8 char boundary. 
+fn safe_split_pos(s: &str, max_bytes: usize) -> usize { + if max_bytes >= s.len() { + return s.len(); + } + let mut pos = max_bytes; + while pos > 0 && !s.is_char_boundary(pos) { + pos -= 1; + } + // Don't return 0 unless the string is empty — force at least one char + if pos == 0 && !s.is_empty() { + let first_char_len = s.chars().next().unwrap().len_utf8(); + return first_char_len; + } + pos +} diff --git a/src/jar_utils/mod.rs b/src/jar_utils/mod.rs new file mode 100644 index 0000000..9d6b3b9 --- /dev/null +++ b/src/jar_utils/mod.rs @@ -0,0 +1,5 @@ +mod manifest; +mod types; + +pub use self::manifest::*; +pub use self::types::*; diff --git a/src/jar_utils/types.rs b/src/jar_utils/types.rs new file mode 100644 index 0000000..0640cca --- /dev/null +++ b/src/jar_utils/types.rs @@ -0,0 +1,236 @@ +use std::collections::BTreeMap; +use std::io::{Cursor, Read, Seek, Write}; +use std::path::Path; + +use binrw::{BinRead, BinWrite}; +use zip::CompressionMethod; +use zip::write::SimpleFileOptions; + +use crate::ClassFile; + +use super::manifest::JarManifest; + +const MANIFEST_PATH: &str = "META-INF/MANIFEST.MF"; + +// --------------------------------------------------------------------------- +// Error type +// --------------------------------------------------------------------------- + +#[derive(Debug)] +pub enum JarError { + Io(std::io::Error), + Zip(zip::result::ZipError), + ClassParse(binrw::Error), + ManifestParse(String), +} + +impl std::fmt::Display for JarError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JarError::Io(e) => write!(f, "I/O error: {e}"), + JarError::Zip(e) => write!(f, "ZIP error: {e}"), + JarError::ClassParse(e) => write!(f, "class parse error: {e}"), + JarError::ManifestParse(e) => write!(f, "manifest parse error: {e}"), + } + } +} + +impl std::error::Error for JarError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + JarError::Io(e) => Some(e), + JarError::Zip(e) => 
Some(e), + JarError::ClassParse(e) => Some(e), + JarError::ManifestParse(_) => None, + } + } +} + +impl From<std::io::Error> for JarError { + fn from(e: std::io::Error) -> Self { + JarError::Io(e) + } +} + +impl From<zip::result::ZipError> for JarError { + fn from(e: zip::result::ZipError) -> Self { + JarError::Zip(e) + } +} + +impl From<binrw::Error> for JarError { + fn from(e: binrw::Error) -> Self { + JarError::ClassParse(e) + } +} + +pub type JarResult<T> = Result<T, JarError>; + +// --------------------------------------------------------------------------- +// JarFile +// --------------------------------------------------------------------------- + +/// In-memory representation of a JAR (ZIP) archive. +/// +/// Entries are stored as a `BTreeMap<String, Vec<u8>>` mapping entry paths to +/// raw bytes. This avoids lifetime issues with `ZipArchive` and allows free +/// mutation before writing. +#[derive(Clone, Debug)] +pub struct JarFile { + entries: BTreeMap<String, Vec<u8>>, +} + +impl JarFile { + /// Create an empty JAR. + pub fn new() -> Self { + JarFile { + entries: BTreeMap::new(), + } + } + + // -- Reading -- + + /// Read a JAR from any reader. + pub fn read<R: Read + Seek>(reader: R) -> JarResult<Self> { + let mut archive = zip::ZipArchive::new(reader)?; + let mut entries = BTreeMap::new(); + + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + if file.is_dir() { + continue; + } + let name = file.name().to_string(); + let mut data = Vec::with_capacity(file.size() as usize); + file.read_to_end(&mut data)?; + entries.insert(name, data); + } + + Ok(JarFile { entries }) + } + + /// Read a JAR from a byte slice. + pub fn from_bytes(bytes: &[u8]) -> JarResult<Self> { + Self::read(Cursor::new(bytes)) + } + + /// Read a JAR from a file path. + pub fn open(path: impl AsRef<Path>) -> JarResult<Self> { + let file = std::fs::File::open(path)?; + let reader = std::io::BufReader::new(file); + Self::read(reader) + } + + // -- Writing -- + + /// Write the JAR to any writer using Deflated compression.
+ pub fn write<W: Write + Seek>(&self, writer: W) -> JarResult<()> { + let mut zip_writer = zip::ZipWriter::new(writer); + let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated); + + for (name, data) in &self.entries { + zip_writer.start_file(name, options)?; + zip_writer.write_all(data)?; + } + + zip_writer.finish()?; + Ok(()) + } + + /// Serialize the JAR to a byte vector. + pub fn to_bytes(&self) -> JarResult<Vec<u8>> { + let mut buf = Cursor::new(Vec::new()); + self.write(&mut buf)?; + Ok(buf.into_inner()) + } + + /// Write the JAR to a file path. + pub fn save(&self, path: impl AsRef<Path>) -> JarResult<()> { + let file = std::fs::File::create(path)?; + let writer = std::io::BufWriter::new(file); + self.write(writer) + } + + // -- Entry access -- + + /// Iterate over all entry paths (sorted). + pub fn entry_names(&self) -> impl Iterator<Item = &str> { + self.entries.keys().map(|s| s.as_str()) + } + + /// Iterate over `.class` entry paths only. + pub fn class_names(&self) -> impl Iterator<Item = &str> { + self.entry_names().filter(|n| n.ends_with(".class")) + } + + /// Get the raw bytes of an entry. + pub fn get_entry(&self, path: &str) -> Option<&[u8]> { + self.entries.get(path).map(|v| v.as_slice()) + } + + /// Insert or replace an entry. + pub fn set_entry(&mut self, path: impl Into<String>, data: Vec<u8>) { + self.entries.insert(path.into(), data); + } + + /// Remove an entry, returning its data if it existed. + pub fn remove_entry(&mut self, path: &str) -> Option<Vec<u8>> { + self.entries.remove(path) + } + + /// Check whether an entry exists. + pub fn contains_entry(&self, path: &str) -> bool { + self.entries.contains_key(path) + } + + // -- ClassFile integration -- + + /// Parse a `.class` entry into a `ClassFile`.
+ pub fn parse_class(&self, path: &str) -> JarResult<ClassFile> { + let data = self.get_entry(path).ok_or_else(|| { + JarError::Io(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("entry not found: {path}"), + )) + })?; + let mut cursor = Cursor::new(data); + let class_file = ClassFile::read(&mut cursor)?; + Ok(class_file) + } + + /// Parse all `.class` entries. Returns a vec of `(path, result)` pairs. + pub fn parse_all_classes(&self) -> Vec<(String, JarResult<ClassFile>)> { + self.class_names() + .map(|name| name.to_string()) + .collect::<Vec<_>>() + .into_iter() + .map(|name| { + let result = self.parse_class(&name); + (name, result) + }) + .collect() + } + + /// Serialize a `ClassFile` and store it as an entry. + pub fn set_class(&mut self, path: &str, class_file: &ClassFile) -> JarResult<()> { + let mut buf = Cursor::new(Vec::new()); + class_file.write(&mut buf)?; + self.set_entry(path.to_string(), buf.into_inner()); + Ok(()) + } + + // -- Manifest integration -- + + /// Parse the `META-INF/MANIFEST.MF` entry if present. + pub fn manifest(&self) -> JarResult<Option<JarManifest>> { + match self.get_entry(MANIFEST_PATH) { + Some(data) => Ok(Some(JarManifest::parse(data)?)), + None => Ok(None), + } + } + + /// Serialize a `JarManifest` and store it as `META-INF/MANIFEST.MF`. + pub fn set_manifest(&mut self, manifest: &JarManifest) { + self.set_entry(MANIFEST_PATH.to_string(), manifest.to_bytes()); + } +} diff --git a/src/lib.rs b/src/lib.rs index f0ad88e..f0780e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,5 @@ //!
A parser for [Java Classfiles](https://docs.oracle.com/javase/specs/jvms/se10/html/jvms-4.html) -use std::fs::File; -use std::io::{BufReader, prelude::*}; -use std::path::Path; - -#[macro_use] -extern crate nom; - #[macro_use] extern crate bitflags; @@ -17,67 +10,21 @@ pub mod method_info; pub mod code_attribute; -pub mod parser; pub mod types; -pub use parser::class_parser; pub use types::*; -/// Attempt to parse a class file given a path to a class file (without .class extension) -/// -/// ```rust -/// match classfile_parser::parse_class("./java-assets/compiled-classes/BasicClass") { -/// Ok(class_file) => { -/// println!("version {},{}", class_file.major_version, class_file.minor_version); -/// } -/// Err(ex) => panic!("Failed to parse: {}", ex), -/// }; -/// ``` -pub fn parse_class(class_name: &str) -> Result { - let class_file_name = &format!("{}.class", class_name); - let path = Path::new(class_file_name); - let display = path.display(); - - let file = match File::open(path) { - Err(why) => { - return Err(format!("Unable to open {}: {}", display, &why.to_string())); - } - Ok(file) => file, - }; +#[cfg(feature = "decompile")] +pub mod decompile; - let mut reader = BufReader::new(file); - parse_class_from_reader(&mut reader) -} +#[cfg(feature = "compile")] +pub mod compile; -/// Attempt to parse a class file given a reader that implements the std::io::Read trait. -/// Parameters shouldn't be passed for the sole purpose of debug output, this should be -/// abstracted instead. -/// OLD: The file_path parameter is only used in case of errors to provide -/// reasonable error messages. 
-/// -/// ```rust -/// let mut reader = "this_will_be_parsed_as_classfile".as_bytes(); -/// let result = classfile_parser::parse_class_from_reader(&mut reader); -/// assert!(result.is_err()); -/// ``` -pub fn parse_class_from_reader(reader: &mut T) -> Result { - let mut class_bytes = Vec::new(); - reader - .read_to_end(&mut class_bytes) - .expect("cannot continue, read_to_end failed"); +#[cfg(feature = "jar-utils")] +pub mod jar_utils; - let parsed_class = class_parser(&class_bytes); - match parsed_class { - Ok((a, c)) => { - if !a.is_empty() { - eprintln!( - "Warning: not all bytes were consumed when parsing classfile, {} bytes remaining", - a.len() - ); - } +#[cfg(all(feature = "compile", feature = "jar-utils"))] +pub mod jar_patch; - Ok(c) - } - Err(e) => Err(format!("Failed to parse classfile: {}", e)), - } -} +#[cfg(feature = "spring-utils")] +pub mod spring_utils; diff --git a/src/method_info/mod.rs b/src/method_info/mod.rs index b69dec1..6108ca4 100644 --- a/src/method_info/mod.rs +++ b/src/method_info/mod.rs @@ -1,5 +1,3 @@ -mod parser; mod types; -pub use self::parser::method_parser; pub use self::types::*; diff --git a/src/method_info/parser.rs b/src/method_info/parser.rs deleted file mode 100644 index 0f66fd0..0000000 --- a/src/method_info/parser.rs +++ /dev/null @@ -1,23 +0,0 @@ -use nom::{IResult, multi::count, number::complete::be_u16}; - -use crate::attribute_info::attribute_parser; - -use crate::method_info::{MethodAccessFlags, MethodInfo}; - -pub fn method_parser(input: &[u8]) -> IResult<&[u8], MethodInfo> { - let (input, access_flags) = be_u16(input)?; - let (input, name_index) = be_u16(input)?; - let (input, descriptor_index) = be_u16(input)?; - let (input, attributes_count) = be_u16(input)?; - let (input, attributes) = count(attribute_parser, attributes_count as usize)(input)?; - Ok(( - input, - MethodInfo { - access_flags: MethodAccessFlags::from_bits_truncate(access_flags), - name_index, - descriptor_index, - attributes_count, - attributes, - 
}, - )) -} diff --git a/src/method_info/types.rs b/src/method_info/types.rs index 411950c..72bd0c8 100644 --- a/src/method_info/types.rs +++ b/src/method_info/types.rs @@ -1,6 +1,9 @@ -use crate::attribute_info::AttributeInfo; +use crate::{ + InterpretInner, + attribute_info::{AttributeInfo, AttributeInfoVariant, CodeAttribute}, +}; -use binrw::binrw; +use binrw::{BinResult, binrw}; #[derive(Clone, Debug)] #[binrw] @@ -10,10 +13,73 @@ pub struct MethodInfo { pub name_index: u16, pub descriptor_index: u16, pub attributes_count: u16, - #[br(args { count: attributes_count.into() })] + #[br(count = attributes_count)] pub attributes: Vec, } +impl InterpretInner for MethodInfo { + fn interpret_inner(&mut self, const_pool: &Vec) { + for attr in &mut self.attributes { + attr.interpret_inner(const_pool); + } + } +} + +impl MethodInfo { + /// Returns a reference to the Code attribute, if present. + pub fn code(&self) -> Option<&CodeAttribute> { + self.attributes.iter().find_map(|a| match &a.info_parsed { + Some(AttributeInfoVariant::Code(c)) => Some(c), + _ => None, + }) + } + + /// Returns a mutable reference to the Code attribute, if present. + pub fn code_mut(&mut self) -> Option<&mut CodeAttribute> { + self.attributes + .iter_mut() + .find_map(|a| match &mut a.info_parsed { + Some(AttributeInfoVariant::Code(c)) => Some(c), + _ => None, + }) + } + + /// Returns a reference to the AttributeInfo containing the Code attribute. + /// Useful when you need to call `sync_from_parsed()` after modifying the code. + pub fn code_attribute_info(&self) -> Option<&AttributeInfo> { + self.attributes + .iter() + .find(|a| matches!(&a.info_parsed, Some(AttributeInfoVariant::Code(_)))) + } + + /// Returns a mutable reference to the AttributeInfo containing the Code attribute. + /// Useful when you need to call `sync_from_parsed()` after modifying the code. 
+ pub fn code_attribute_info_mut(&mut self) -> Option<&mut AttributeInfo> { + self.attributes + .iter_mut() + .find(|a| matches!(&a.info_parsed, Some(AttributeInfoVariant::Code(_)))) + } + + /// Access the CodeAttribute inside a closure and auto-sync when done. + /// + /// Returns `None` if no Code attribute exists (e.g. abstract method), + /// `Some(Err(_))` if sync fails, `Some(Ok(R))` on success. + pub fn with_code<F, R>(&mut self, f: F) -> Option<BinResult<R>> + where + F: FnOnce(&mut CodeAttribute) -> R, + { + let attr = self + .attributes + .iter_mut() + .find(|a| matches!(&a.info_parsed, Some(AttributeInfoVariant::Code(_))))?; + let result = match &mut attr.info_parsed { + Some(AttributeInfoVariant::Code(code)) => f(code), + _ => unreachable!(), + }; + Some(attr.sync_from_parsed().map(|()| result)) + } +} + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] #[binrw] pub struct MethodAccessFlags(u16); @@ -34,13 +100,3 @@ bitflags! { const SYNTHETIC = 0x1000; // Declared synthetic; not present in the source code. } } - -#[cfg(test)] -#[allow(dead_code)] -trait TraitTester: - Copy + Clone + PartialEq + Eq + PartialOrd + Ord + ::std::hash::Hash + ::std::fmt::Debug -{ -} - -#[cfg(test)] -impl TraitTester for MethodAccessFlags {} diff --git a/src/parser.rs b/src/parser.rs deleted file mode 100644 index 4bc75ef..0000000 --- a/src/parser.rs +++ /dev/null @@ -1,69 +0,0 @@ -use nom::*; - -use crate::attribute_info::attribute_parser; -use crate::constant_info::constant_parser; -use crate::field_info::field_parser; -use crate::method_info::method_parser; -use crate::types::{ClassAccessFlags, ClassFile}; -use nom::bytes::complete::tag; -use nom::multi::count; - -fn magic_parser(input: &[u8]) -> IResult<&[u8], &[u8]> { - tag(&[0xCA, 0xFE, 0xBA, 0xBE])(input) -} - -/// Parse a byte array into a ClassFile. This will probably be deprecated in 0.4.0 in as it returns -/// a nom IResult type, which exposes the internal parsing library and not a good idea.
-/// -/// If you want to call it directly, as it is the only way to parse a byte slice directly, you must -/// unwrap the result yourself. -/// -/// ```rust -/// let classfile_bytes = include_bytes!("../java-assets/compiled-classes/BasicClass.class"); -/// -/// match classfile_parser::class_parser(classfile_bytes) { -/// Ok((_, class_file)) => { -/// println!("version {},{}", class_file.major_version, class_file.minor_version); -/// } -/// Err(_) => panic!("Failed to parse"), -/// }; -/// ``` -pub fn class_parser(input: &[u8]) -> IResult<&[u8], ClassFile> { - use nom::number::complete::be_u16; - let (input, _) = magic_parser(input)?; - let (input, minor_version) = be_u16(input)?; - let (input, major_version) = be_u16(input)?; - let (input, const_pool_size) = be_u16(input)?; - let (input, const_pool) = constant_parser(input, (const_pool_size - 1) as usize)?; - let (input, access_flags) = be_u16(input)?; - let (input, this_class) = be_u16(input)?; - let (input, super_class) = be_u16(input)?; - let (input, interfaces_count) = be_u16(input)?; - let (input, interfaces) = count(be_u16, interfaces_count as usize)(input)?; - let (input, fields_count) = be_u16(input)?; - let (input, fields) = count(field_parser, fields_count as usize)(input)?; - let (input, methods_count) = be_u16(input)?; - let (input, methods) = count(method_parser, methods_count as usize)(input)?; - let (input, attributes_count) = be_u16(input)?; - let (input, attributes) = count(attribute_parser, attributes_count as usize)(input)?; - Ok(( - input, - ClassFile { - minor_version, - major_version, - const_pool_size, - const_pool, - access_flags: ClassAccessFlags::from_bits_truncate(access_flags), - this_class, - super_class, - interfaces_count, - interfaces, - fields_count, - fields, - methods_count, - methods, - attributes_count, - attributes, - }, - )) -} diff --git a/src/spring_utils/classpath_idx.rs b/src/spring_utils/classpath_idx.rs new file mode 100644 index 0000000..5d5ee43 --- /dev/null +++ 
b/src/spring_utils/classpath_idx.rs @@ -0,0 +1,71 @@ +use crate::jar_utils::{JarError, JarResult}; + +/// Parsed representation of a Spring Boot `classpath.idx` file. +/// +/// Format: each line is `- "BOOT-INF/lib/some.jar"` +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ClasspathIndex { + entries: Vec, +} + +impl ClasspathIndex { + /// Parse a `classpath.idx` file from raw bytes. + pub fn parse(data: &[u8]) -> JarResult { + let text = std::str::from_utf8(data) + .map_err(|e| JarError::ManifestParse(format!("classpath.idx: invalid UTF-8: {e}")))?; + + let mut entries = Vec::new(); + for line in text.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + // Expected format: - "path" + let Some(rest) = line.strip_prefix("- \"") else { + continue; + }; + let Some(path) = rest.strip_suffix('"') else { + continue; + }; + entries.push(path.to_string()); + } + + Ok(ClasspathIndex { entries }) + } + + /// Serialize back to `classpath.idx` format. + pub fn to_bytes(&self) -> Vec { + let mut out = String::new(); + for entry in &self.entries { + out.push_str("- \""); + out.push_str(entry); + out.push_str("\"\n"); + } + out.into_bytes() + } + + /// All classpath entries. + pub fn entries(&self) -> &[String] { + &self.entries + } + + /// Number of entries. + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Whether the index is empty. + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + /// Check if a path is in the classpath index. + pub fn contains(&self, path: &str) -> bool { + self.entries.iter().any(|e| e == path) + } + + /// Iterate over entries. 
+ pub fn iter(&self) -> impl Iterator<Item = &str> { + self.entries.iter().map(|s| s.as_str()) + } +} diff --git a/src/spring_utils/layers_idx.rs b/src/spring_utils/layers_idx.rs new file mode 100644 index 0000000..2ad6060 --- /dev/null +++ b/src/spring_utils/layers_idx.rs @@ -0,0 +1,125 @@ +use crate::jar_utils::{JarError, JarResult}; + +/// A single layer in a Spring Boot `layers.idx` file. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Layer { + pub name: String, + pub paths: Vec<String>, +} + +/// Parsed representation of a Spring Boot `layers.idx` file. +/// +/// Format: +/// ```text +/// - "dependencies": +///   - "BOOT-INF/lib/" +/// - "application": +///   - "BOOT-INF/classes/" +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LayersIndex { + layers: Vec<Layer>, +} + +impl LayersIndex { + /// Parse a `layers.idx` file from raw bytes. + pub fn parse(data: &[u8]) -> JarResult<LayersIndex> { + let text = std::str::from_utf8(data) + .map_err(|e| JarError::ManifestParse(format!("layers.idx: invalid UTF-8: {e}")))?; + + let mut layers = Vec::new(); + let mut current: Option<Layer> = None; + + for line in text.lines() { + if line.is_empty() { + continue; + } + + // Layer header: - "name": + if line.starts_with("- \"") { + // Flush previous layer + if let Some(layer) = current.take() { + layers.push(layer); + } + let rest = &line[3..]; // after `- "` + if let Some(name) = rest.strip_suffix("\":") { + current = Some(Layer { + name: name.to_string(), + paths: Vec::new(), + }); + } + continue; + } + + // Path entry:   - "path" + if line.starts_with("  - \"") { + let rest = &line[5..]; // after `  - "` + if let Some(path) = rest.strip_suffix('"') { + if let Some(ref mut layer) = current { + layer.paths.push(path.to_string()); + } + } + } + } + + // Flush last layer + if let Some(layer) = current { + layers.push(layer); + } + + Ok(LayersIndex { layers }) + } + + /// Serialize back to `layers.idx` format. 
+ pub fn to_bytes(&self) -> Vec<u8> { + let mut out = String::new(); + for layer in &self.layers { + out.push_str("- \""); + out.push_str(&layer.name); + out.push_str("\":\n"); + for path in &layer.paths { + out.push_str("  - \""); + out.push_str(path); + out.push_str("\"\n"); + } + } + out.into_bytes() + } + + /// All layers. + pub fn layers(&self) -> &[Layer] { + &self.layers + } + + /// Find a layer by name. + pub fn find_layer(&self, name: &str) -> Option<&Layer> { + self.layers.iter().find(|l| l.name == name) + } + + /// Iterate over layer names. + pub fn layer_names(&self) -> impl Iterator<Item = &str> { + self.layers.iter().map(|l| l.name.as_str()) + } + + /// Find which layer an entry path belongs to (by prefix match). + pub fn layer_for_path(&self, entry_path: &str) -> Option<&str> { + for layer in &self.layers { + for path in &layer.paths { + if entry_path.starts_with(path.as_str()) { + return Some(&layer.name); + } + } + } + None + } + + /// Number of layers. + pub fn len(&self) -> usize { + self.layers.len() + } + + /// Whether there are no layers. 
+ pub fn is_empty(&self) -> bool { + self.layers.is_empty() + } +} diff --git a/src/spring_utils/mod.rs b/src/spring_utils/mod.rs new file mode 100644 index 0000000..b6942b2 --- /dev/null +++ b/src/spring_utils/mod.rs @@ -0,0 +1,7 @@ +mod classpath_idx; +mod layers_idx; +mod types; + +pub use self::classpath_idx::*; +pub use self::layers_idx::*; +pub use self::types::*; diff --git a/src/spring_utils/types.rs b/src/spring_utils/types.rs new file mode 100644 index 0000000..49d75c3 --- /dev/null +++ b/src/spring_utils/types.rs @@ -0,0 +1,300 @@ +use std::io::{Read, Seek}; +use std::path::Path; + +use crate::ClassFile; +use crate::jar_utils::{JarError, JarFile, JarResult}; + +use super::classpath_idx::ClasspathIndex; +use super::layers_idx::LayersIndex; + +// --------------------------------------------------------------------------- +// SpringBootFormat +// --------------------------------------------------------------------------- + +/// The packaging format of a Spring Boot fat JAR/WAR. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SpringBootFormat { + /// Standard JAR layout — application classes in `BOOT-INF/`. + Jar, + /// WAR layout — application classes in `WEB-INF/`. + War, +} + +impl SpringBootFormat { + /// The top-level prefix directory (`BOOT-INF` or `WEB-INF`). + pub fn prefix(&self) -> &'static str { + match self { + SpringBootFormat::Jar => "BOOT-INF", + SpringBootFormat::War => "WEB-INF", + } + } + + /// The directory containing application classes. + pub fn classes_dir(&self) -> &'static str { + match self { + SpringBootFormat::Jar => "BOOT-INF/classes/", + SpringBootFormat::War => "WEB-INF/classes/", + } + } + + /// The directory containing dependency JARs. 
+ pub fn lib_dir(&self) -> &'static str { + match self { + SpringBootFormat::Jar => "BOOT-INF/lib/", + SpringBootFormat::War => "WEB-INF/lib/", + } + } +} + +// --------------------------------------------------------------------------- +// Detection +// --------------------------------------------------------------------------- + +/// Known Spring Boot launcher main classes. +const JAR_LAUNCHERS: &[&str] = &[ + "org.springframework.boot.loader.JarLauncher", + "org.springframework.boot.loader.launch.JarLauncher", +]; + +const WAR_LAUNCHERS: &[&str] = &[ + "org.springframework.boot.loader.WarLauncher", + "org.springframework.boot.loader.launch.WarLauncher", +]; + +const PROPERTIES_LAUNCHERS: &[&str] = &[ + "org.springframework.boot.loader.PropertiesLauncher", + "org.springframework.boot.loader.launch.PropertiesLauncher", +]; + +/// Detect the Spring Boot format from a `JarFile` by inspecting the manifest. +/// +/// Returns `None` if this is not a Spring Boot fat JAR/WAR. +pub fn detect_format(jar: &JarFile) -> Option { + let manifest = jar.manifest().ok()??; + + // Must have Start-Class + manifest.main_attr("Start-Class")?; + + let main_class = manifest.main_attr("Main-Class")?; + + if JAR_LAUNCHERS.iter().any(|&l| l == main_class) { + return Some(SpringBootFormat::Jar); + } + if WAR_LAUNCHERS.iter().any(|&l| l == main_class) { + return Some(SpringBootFormat::War); + } + if PROPERTIES_LAUNCHERS.iter().any(|&l| l == main_class) { + // Infer format from directory structure + if jar.entry_names().any(|n| n.starts_with("BOOT-INF/")) { + return Some(SpringBootFormat::Jar); + } + if jar.entry_names().any(|n| n.starts_with("WEB-INF/")) { + return Some(SpringBootFormat::War); + } + // Default to Jar for PropertiesLauncher + return Some(SpringBootFormat::Jar); + } + + None +} + +// --------------------------------------------------------------------------- +// SpringBootJar +// --------------------------------------------------------------------------- + +/// A Spring Boot 
fat JAR/WAR wrapping a `JarFile`. +#[derive(Clone, Debug)] +pub struct SpringBootJar { + jar: JarFile, + format: SpringBootFormat, +} + +impl SpringBootJar { + // -- Construction / Detection -- + + /// Wrap an existing `JarFile` if it is a Spring Boot fat JAR/WAR. + pub fn from_jar(jar: JarFile) -> Option { + let format = detect_format(&jar)?; + Some(SpringBootJar { jar, format }) + } + + /// Read a JAR and detect Spring Boot format. + /// Returns `Ok(None)` if it is not a Spring Boot fat JAR. + pub fn read(reader: R) -> JarResult> { + let jar = JarFile::read(reader)?; + Ok(Self::from_jar(jar)) + } + + /// Read from bytes and detect. + pub fn from_bytes(bytes: &[u8]) -> JarResult> { + let jar = JarFile::from_bytes(bytes)?; + Ok(Self::from_jar(jar)) + } + + /// Open from a file path and detect. + pub fn open(path: impl AsRef) -> JarResult> { + let jar = JarFile::open(path)?; + Ok(Self::from_jar(jar)) + } + + // -- Accessors -- + + /// The underlying `JarFile`. + pub fn jar(&self) -> &JarFile { + &self.jar + } + + /// Mutable access to the underlying `JarFile`. + pub fn jar_mut(&mut self) -> &mut JarFile { + &mut self.jar + } + + /// Consume the wrapper, returning the inner `JarFile`. + pub fn into_jar(self) -> JarFile { + self.jar + } + + /// The detected packaging format. + pub fn format(&self) -> SpringBootFormat { + self.format + } + + // -- Manifest shortcuts -- + + /// The `Start-Class` manifest attribute (the actual application main class). + pub fn start_class(&self) -> JarResult> { + Ok(self + .jar + .manifest()? + .and_then(|m| m.main_attr("Start-Class").map(|s| s.to_string()))) + } + + /// The `Spring-Boot-Version` manifest attribute. + pub fn spring_boot_version(&self) -> JarResult> { + Ok(self + .jar + .manifest()? + .and_then(|m| m.main_attr("Spring-Boot-Version").map(|s| s.to_string()))) + } + + /// The `Spring-Boot-Classes` manifest attribute. + pub fn spring_boot_classes_path(&self) -> JarResult> { + Ok(self + .jar + .manifest()? 
+ .and_then(|m| m.main_attr("Spring-Boot-Classes").map(|s| s.to_string()))) + } + + /// The `Spring-Boot-Lib` manifest attribute. + pub fn spring_boot_lib_path(&self) -> JarResult> { + Ok(self + .jar + .manifest()? + .and_then(|m| m.main_attr("Spring-Boot-Lib").map(|s| s.to_string()))) + } + + // -- Application classes -- + + /// Iterate over `.class` file paths under the classes directory. + pub fn app_class_names(&self) -> impl Iterator { + let classes_dir = self.format.classes_dir(); + self.jar + .entry_names() + .filter(move |n| n.starts_with(classes_dir) && n.ends_with(".class")) + } + + /// Iterate over non-`.class` resource paths under the classes directory. + pub fn app_resource_names(&self) -> impl Iterator { + let classes_dir = self.format.classes_dir(); + self.jar + .entry_names() + .filter(move |n| n.starts_with(classes_dir) && !n.ends_with(".class")) + } + + /// Parse a `.class` file from the classes directory. + pub fn parse_app_class(&self, path: &str) -> JarResult { + self.jar.parse_class(path) + } + + /// Parse all `.class` files under the classes directory. + pub fn parse_all_app_classes(&self) -> Vec<(String, JarResult)> { + self.app_class_names() + .map(|n| n.to_string()) + .collect::>() + .into_iter() + .map(|name| { + let result = self.jar.parse_class(&name); + (name, result) + }) + .collect() + } + + // -- Loader classes -- + + /// Iterate over Spring Boot loader class paths. + pub fn loader_class_names(&self) -> impl Iterator { + self.jar + .entry_names() + .filter(|n| n.starts_with("org/springframework/boot/loader/") && n.ends_with(".class")) + } + + // -- Nested JARs -- + + /// Iterate over nested JAR paths under the lib directory. + pub fn nested_jar_names(&self) -> impl Iterator { + let lib_dir = self.format.lib_dir(); + self.jar + .entry_names() + .filter(move |n| n.starts_with(lib_dir) && n.ends_with(".jar")) + } + + /// Open a nested JAR by its path within the fat JAR. 
+ pub fn open_nested_jar(&self, path: &str) -> JarResult { + let data = self.jar.get_entry(path).ok_or_else(|| { + JarError::Io(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("nested JAR not found: {path}"), + )) + })?; + JarFile::from_bytes(data) + } + + /// Open all nested JARs. Returns `(path, result)` pairs. + pub fn open_all_nested_jars(&self) -> Vec<(String, JarResult)> { + self.nested_jar_names() + .map(|n| n.to_string()) + .collect::>() + .into_iter() + .map(|name| { + let result = self.open_nested_jar(&name); + (name, result) + }) + .collect() + } + + /// Parse a `.class` file from inside a nested JAR. + pub fn parse_nested_class(&self, jar_path: &str, class_path: &str) -> JarResult { + let nested = self.open_nested_jar(jar_path)?; + nested.parse_class(class_path) + } + + // -- Index file access -- + + /// Parse the `classpath.idx` file if present. + pub fn classpath_index(&self) -> JarResult> { + let idx_path = format!("{}/classpath.idx", self.format.prefix()); + match self.jar.get_entry(&idx_path) { + Some(data) => Ok(Some(ClasspathIndex::parse(data)?)), + None => Ok(None), + } + } + + /// Parse the `layers.idx` file if present. 
+ pub fn layers_index(&self) -> JarResult> { + let idx_path = format!("{}/layers.idx", self.format.prefix()); + match self.jar.get_entry(&idx_path) { + Some(data) => Ok(Some(LayersIndex::parse(data)?)), + None => Ok(None), + } + } +} diff --git a/src/types.rs b/src/types.rs index 8616790..74db8da 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,36 +1,652 @@ +use std::io::{Read, Seek}; + use crate::attribute_info::AttributeInfo; -use crate::constant_info::ConstantInfo; +use crate::constant_info::{ + ClassConstant, ConstantInfo, DoubleConstant, FieldRefConstant, FloatConstant, IntegerConstant, + InterfaceMethodRefConstant, InvokeDynamicConstant, LongConstant, MethodHandleConstant, + MethodRefConstant, MethodTypeConstant, NameAndTypeConstant, StringConstant, Utf8Constant, +}; use crate::field_info::FieldInfo; use crate::method_info::MethodInfo; -use binrw::binrw; +use binrw::{ + BinRead, BinResult, BinWrite, Endian, VecArgs, binrw, + meta::{EndianKind, ReadEndian}, +}; -#[derive(Clone, Debug)] -#[binrw] +/// Custom writer for the constant pool that skips Unusable sentinel entries. +/// +/// On read, Long and Double constants occupy two slots in the constant pool, +/// and we insert an `Unusable` placeholder for the second slot. On write, +/// we must skip these placeholders since they are not part of the binary format. 
+#[binrw::writer(writer, endian)] +fn write_const_pool(pool: &Vec<ConstantInfo>) -> BinResult<()> { + for item in pool { + if !matches!(item, ConstantInfo::Unusable) { + item.write_options(writer, endian, ())?; + } + } + Ok(()) +} + +#[derive(BinWrite, Clone, Debug)] #[brw(big, magic = b"\xca\xfe\xba\xbe")] pub struct ClassFile { pub minor_version: u16, pub major_version: u16, pub const_pool_size: u16, - #[br(args { count: const_pool_size.into() })] + #[bw(write_with = write_const_pool)] pub const_pool: Vec<ConstantInfo>, pub access_flags: ClassAccessFlags, pub this_class: u16, pub super_class: u16, pub interfaces_count: u16, - #[br(args { count: interfaces_count.into() })] pub interfaces: Vec<u16>, pub fields_count: u16, - #[br(args { count: fields_count.into() })] pub fields: Vec<FieldInfo>, pub methods_count: u16, - #[br(args { count: methods_count.into() })] pub methods: Vec<MethodInfo>, pub attributes_count: u16, - #[br(args { count: attributes_count.into() })] pub attributes: Vec<AttributeInfo>, } +pub trait InterpretInner { + fn interpret_inner(&mut self, const_pool: &Vec<ConstantInfo>); +} + +impl ReadEndian for ClassFile { + const ENDIAN: EndianKind = EndianKind::Endian(Endian::Big); +} + +fn const_pool_parser<R: Read + Seek>( + r: &mut R, + endian: Endian, + args: VecArgs<()>, +) -> BinResult<Vec<ConstantInfo>> { + let count = args.count.saturating_sub(1); + // Each CP entry is at least 3 bytes (1 tag + 2 data). + validate_count_vs_remaining(r, count, 3, "constant_pool_count")?; + let mut v = vec![]; + while v.len() < count { + v.push(ConstantInfo::read_options(r, endian, args.inner)?); + if matches!( + v.last().unwrap(), + ConstantInfo::Double(_) | ConstantInfo::Long(_) + ) { + v.push(ConstantInfo::Unusable); + } + } + + Ok(v) +} + +/// Validate that `count * min_entry_size` fits within the remaining data. 
+fn validate_count_vs_remaining( + r: &mut R, + count: usize, + min_entry_size: usize, + label: &str, +) -> BinResult<()> { + if count == 0 { + return Ok(()); + } + let pos = r.stream_position().map_err(binrw::Error::Io)?; + let end = r + .seek(std::io::SeekFrom::End(0)) + .map_err(binrw::Error::Io)?; + r.seek(std::io::SeekFrom::Start(pos)) + .map_err(binrw::Error::Io)?; + let remaining = end.saturating_sub(pos) as usize; + if count * min_entry_size > remaining { + return Err(binrw::Error::AssertFail { + pos, + message: format!( + "{} {} requires at least {} bytes but only {} remain", + label, + count, + count * min_entry_size, + remaining + ), + }); + } + Ok(()) +} + +impl BinRead for ClassFile { + type Args<'a> = (); + + fn read_options( + reader: &mut R, + _endian: binrw::Endian, + _args: Self::Args<'_>, + ) -> binrw::BinResult { + let magic = u32::read_options(reader, Endian::Big, ())?; + if magic != u32::from_be_bytes([0xca, 0xfe, 0xba, 0xbe]) { + return Err(binrw::Error::BadMagic { + pos: 0, + found: Box::new(magic), + }); + } + + let minor_version = u16::read_options(reader, Endian::Big, ())?; + let major_version = u16::read_options(reader, Endian::Big, ())?; + let const_pool_size = u16::read_options(reader, Endian::Big, ())?; + let const_pool = const_pool_parser( + reader, + Endian::Big, + VecArgs { + count: const_pool_size as usize, + inner: (), + }, + )?; + + let access_flags = ClassAccessFlags::read_options(reader, Endian::Big, ())?; + let this_class = u16::read_options(reader, Endian::Big, ())?; + let super_class = u16::read_options(reader, Endian::Big, ())?; + let interfaces_count = u16::read_options(reader, Endian::Big, ())?; + // Each interface is a u16 (2 bytes) + validate_count_vs_remaining(reader, interfaces_count as usize, 2, "interfaces_count")?; + let interfaces = Vec::::read_options( + reader, + Endian::Big, + VecArgs { + count: interfaces_count as usize, + inner: (), + }, + )?; + let fields_count = u16::read_options(reader, Endian::Big, ())?; 
+ // Each field is at least 8 bytes (flags + name_idx + desc_idx + attr_count) + validate_count_vs_remaining(reader, fields_count as usize, 8, "fields_count")?; + let mut fields = Vec::::read_options( + reader, + Endian::Big, + VecArgs { + count: fields_count as usize, + inner: (), + }, + )?; + + let methods_count = u16::read_options(reader, Endian::Big, ())?; + // Each method is at least 8 bytes (flags + name_idx + desc_idx + attr_count) + validate_count_vs_remaining(reader, methods_count as usize, 8, "methods_count")?; + let mut methods = Vec::::read_options( + reader, + Endian::Big, + VecArgs { + count: methods_count as usize, + inner: (), + }, + )?; + + let attributes_count = u16::read_options(reader, Endian::Big, ())?; + // Each attribute is at least 6 bytes (name_idx + length) + validate_count_vs_remaining(reader, attributes_count as usize, 6, "attributes_count")?; + let mut attributes = Vec::::read_options( + reader, + Endian::Big, + VecArgs { + count: attributes_count as usize, + inner: (), + }, + )?; + + for field in &mut fields { + field.interpret_inner(&const_pool); + } + + for method in &mut methods { + method.interpret_inner(&const_pool); + } + + for attr in &mut attributes { + attr.interpret_inner(&const_pool); + } + + Ok(ClassFile { + minor_version, + major_version, + const_pool_size, + const_pool, + access_flags, + this_class, + super_class, + interfaces_count, + interfaces, + fields_count, + fields, + methods_count, + methods, + attributes_count, + attributes, + }) + } +} + +impl ClassFile { + /// Recalculates all count fields from actual vector lengths. + /// Call this after adding or removing entries from const_pool, interfaces, + /// fields, methods, or attributes. 
+ pub fn sync_counts(&mut self) { + fn checked_u16(val: usize, field: &str) -> u16 { + u16::try_from(val) + .unwrap_or_else(|_| panic!("{} count {} exceeds u16::MAX", field, val)) + } + self.const_pool_size = checked_u16(self.const_pool.len() + 1, "const_pool"); + self.interfaces_count = checked_u16(self.interfaces.len(), "interfaces"); + self.fields_count = checked_u16(self.fields.len(), "fields"); + self.methods_count = checked_u16(self.methods.len(), "methods"); + self.attributes_count = checked_u16(self.attributes.len(), "attributes"); + } + + /// Look up a UTF-8 constant pool entry by its 1-based index. + /// Returns `None` if the index is 0, out of range, or does not point to a Utf8 entry. + pub fn get_utf8(&self, index: u16) -> Option<&str> { + match self.const_pool.get(usize::from(index).checked_sub(1)?)? { + ConstantInfo::Utf8(u) => Some(&u.utf8_string), + _ => None, + } + } + + /// Find the 1-based constant pool index of a UTF-8 entry matching the given string. + pub fn find_utf8_index(&self, value: &str) -> Option<u16> { + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::Utf8(u) = entry + && u.utf8_string == value + { + return Some((i + 1) as u16); + } + } + None + } + + /// Find a method by name. + pub fn find_method(&self, name: &str) -> Option<&MethodInfo> { + self.methods + .iter() + .find(|m| self.get_utf8(m.name_index) == Some(name)) + } + + /// Find a method by name, returning a mutable reference. + pub fn find_method_mut(&mut self, name: &str) -> Option<&mut MethodInfo> { + let idx = self + .methods + .iter() + .position(|m| self.get_utf8(m.name_index) == Some(name))?; + Some(&mut self.methods[idx]) + } + + /// Find a field by name. + pub fn find_field(&self, name: &str) -> Option<&FieldInfo> { + self.fields + .iter() + .find(|f| self.get_utf8(f.name_index) == Some(name)) + } + + /// Find a field by name, returning a mutable reference. 
+ pub fn find_field_mut(&mut self, name: &str) -> Option<&mut FieldInfo> { + let idx = self + .fields + .iter() + .position(|f| self.get_utf8(f.name_index) == Some(name))?; + Some(&mut self.fields[idx]) + } + + /// Add a UTF-8 constant to the pool. Returns the 1-based index. + /// Always adds a new entry (no dedup). Does NOT call `sync_counts()`. + pub fn add_utf8(&mut self, value: &str) -> u16 { + let index = (self.const_pool.len() + 1) as u16; + self.const_pool.push(ConstantInfo::Utf8(Utf8Constant { + utf8_string: String::from(value), + })); + index + } + + /// Get or add a UTF-8 constant. Returns existing index if found, otherwise adds. + pub fn get_or_add_utf8(&mut self, value: &str) -> u16 { + if let Some(idx) = self.find_utf8_index(value) { + idx + } else { + self.add_utf8(value) + } + } + + /// Add a String constant (Utf8 + String pair). Returns the String constant's 1-based index. + pub fn add_string(&mut self, value: &str) -> u16 { + let utf8_index = self.add_utf8(value); + let string_index = (self.const_pool.len() + 1) as u16; + self.const_pool.push(ConstantInfo::String(StringConstant { + string_index: utf8_index, + })); + string_index + } + + /// Get or add a String constant, deduplicating both the Utf8 and String entries. + pub fn get_or_add_string(&mut self, value: &str) -> u16 { + let utf8_index = self.get_or_add_utf8(value); + // Search for an existing String constant pointing to this Utf8 + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::String(s) = entry + && s.string_index == utf8_index + { + return (i + 1) as u16; + } + } + let string_index = (self.const_pool.len() + 1) as u16; + self.const_pool.push(ConstantInfo::String(StringConstant { + string_index: utf8_index, + })); + string_index + } + + /// Add a Class constant (Utf8 + Class pair). `name` in internal form (e.g. `"java/lang/String"`). + /// Returns the Class constant's 1-based index. 
+ pub fn add_class(&mut self, name: &str) -> u16 { + let utf8_index = self.add_utf8(name); + let class_index = (self.const_pool.len() + 1) as u16; + self.const_pool.push(ConstantInfo::Class(ClassConstant { + name_index: utf8_index, + })); + class_index + } + + /// Get or add a Class constant, deduplicating both the Utf8 and Class entries. + pub fn get_or_add_class(&mut self, name: &str) -> u16 { + let utf8_index = self.get_or_add_utf8(name); + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::Class(c) = entry + && c.name_index == utf8_index + { + return (i + 1) as u16; + } + } + let class_index = (self.const_pool.len() + 1) as u16; + self.const_pool.push(ConstantInfo::Class(ClassConstant { + name_index: utf8_index, + })); + class_index + } + + /// Add a NameAndType constant. Utf8 entries for `name` and `descriptor` are deduped + /// via `get_or_add_utf8`. Always adds a new NameAndType entry. + pub fn add_name_and_type(&mut self, name: &str, descriptor: &str) -> u16 { + let name_index = self.get_or_add_utf8(name); + let descriptor_index = self.get_or_add_utf8(descriptor); + let nat_index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::NameAndType(NameAndTypeConstant { + name_index, + descriptor_index, + })); + nat_index + } + + /// Get or add a NameAndType constant, deduplicating. 
+ pub fn get_or_add_name_and_type(&mut self, name: &str, descriptor: &str) -> u16 { + let name_index = self.get_or_add_utf8(name); + let descriptor_index = self.get_or_add_utf8(descriptor); + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::NameAndType(nat) = entry + && nat.name_index == name_index + && nat.descriptor_index == descriptor_index + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::NameAndType(NameAndTypeConstant { + name_index, + descriptor_index, + })); + index + } + + /// Get or add a MethodRef constant, deduplicating. + pub fn get_or_add_method_ref( + &mut self, + class_name: &str, + method_name: &str, + descriptor: &str, + ) -> u16 { + let class_index = self.get_or_add_class(class_name); + let nat_index = self.get_or_add_name_and_type(method_name, descriptor); + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::MethodRef(r) = entry + && r.class_index == class_index + && r.name_and_type_index == nat_index + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::MethodRef(MethodRefConstant { + class_index, + name_and_type_index: nat_index, + })); + index + } + + /// Get or add a FieldRef constant, deduplicating. 
+ pub fn get_or_add_field_ref( + &mut self, + class_name: &str, + field_name: &str, + descriptor: &str, + ) -> u16 { + let class_index = self.get_or_add_class(class_name); + let nat_index = self.get_or_add_name_and_type(field_name, descriptor); + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::FieldRef(r) = entry + && r.class_index == class_index + && r.name_and_type_index == nat_index + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::FieldRef(FieldRefConstant { + class_index, + name_and_type_index: nat_index, + })); + index + } + + /// Get or add an InterfaceMethodRef constant, deduplicating. + pub fn get_or_add_interface_method_ref( + &mut self, + class_name: &str, + method_name: &str, + descriptor: &str, + ) -> u16 { + let class_index = self.get_or_add_class(class_name); + let nat_index = self.get_or_add_name_and_type(method_name, descriptor); + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::InterfaceMethodRef(r) = entry + && r.class_index == class_index + && r.name_and_type_index == nat_index + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool.push(ConstantInfo::InterfaceMethodRef( + InterfaceMethodRefConstant { + class_index, + name_and_type_index: nat_index, + }, + )); + index + } + + /// Get or add an Integer constant, deduplicating. + pub fn get_or_add_integer(&mut self, value: i32) -> u16 { + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::Integer(c) = entry + && c.value == value + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::Integer(IntegerConstant { value })); + index + } + + /// Get or add a Float constant, deduplicating. 
+ pub fn get_or_add_float(&mut self, value: f32) -> u16 { + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::Float(c) = entry + && c.value.to_bits() == value.to_bits() + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::Float(FloatConstant { value })); + index + } + + /// Get or add a Long constant, deduplicating. Adds Unusable sentinel. + pub fn get_or_add_long(&mut self, value: i64) -> u16 { + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::Long(c) = entry + && c.value == value + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::Long(LongConstant { value })); + self.const_pool.push(ConstantInfo::Unusable); + index + } + + /// Get or add a Double constant, deduplicating. Adds Unusable sentinel. + pub fn get_or_add_double(&mut self, value: f64) -> u16 { + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::Double(c) = entry + && c.value.to_bits() == value.to_bits() + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::Double(DoubleConstant { value })); + self.const_pool.push(ConstantInfo::Unusable); + index + } + + pub fn get_or_add_method_handle(&mut self, reference_kind: u8, reference_index: u16) -> u16 { + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::MethodHandle(c) = entry + && c.reference_kind == reference_kind + && c.reference_index == reference_index + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::MethodHandle(MethodHandleConstant { + reference_kind, + reference_index, + })); + index + } + + pub fn get_or_add_method_type(&mut self, descriptor: &str) -> u16 { + let desc_idx = self.get_or_add_utf8(descriptor); + for (i, entry) in 
self.const_pool.iter().enumerate() { + if let ConstantInfo::MethodType(c) = entry + && c.descriptor_index == desc_idx + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::MethodType(MethodTypeConstant { + descriptor_index: desc_idx, + })); + index + } + + pub fn get_or_add_invoke_dynamic( + &mut self, + bootstrap_method_attr_index: u16, + name: &str, + descriptor: &str, + ) -> u16 { + let nat_idx = self.get_or_add_name_and_type(name, descriptor); + for (i, entry) in self.const_pool.iter().enumerate() { + if let ConstantInfo::InvokeDynamic(c) = entry + && c.bootstrap_method_attr_index == bootstrap_method_attr_index + && c.name_and_type_index == nat_idx + { + return (i + 1) as u16; + } + } + let index = (self.const_pool.len() + 1) as u16; + self.const_pool + .push(ConstantInfo::InvokeDynamic(InvokeDynamicConstant { + bootstrap_method_attr_index, + name_and_type_index: nat_idx, + })); + index + } + + /// Sync everything after a patching session: calls `sync_from_parsed()` on all + /// attributes (methods, fields, class-level), then `sync_counts()`. + pub fn sync_all(&mut self) -> BinResult<()> { + for method in &mut self.methods { + for attr in &mut method.attributes { + attr.sync_from_parsed()?; + } + } + for field in &mut self.fields { + for attr in &mut field.attributes { + attr.sync_from_parsed()?; + } + } + for attr in &mut self.attributes { + attr.sync_from_parsed()?; + } + self.sync_counts(); + Ok(()) + } + + /// Parse a `ClassFile` from raw `.class` bytes. + /// + /// ```no_run + /// let bytes = std::fs::read("HelloWorld.class").unwrap(); + /// let class_file = classfile_parser::ClassFile::from_bytes(&bytes).unwrap(); + /// ``` + pub fn from_bytes(bytes: &[u8]) -> BinResult { + use std::io::Cursor; + Self::read(&mut Cursor::new(bytes)) + } + + /// Serialize this `ClassFile` back to `.class` bytes. 
+ /// + /// ```no_run + /// # let class_file = classfile_parser::ClassFile::from_bytes(&[]).unwrap(); + /// let bytes = class_file.to_bytes().unwrap(); + /// std::fs::write("HelloWorld.class", bytes).unwrap(); + /// ``` + pub fn to_bytes(&self) -> BinResult<Vec<u8>> { + use std::io::Cursor; + let mut out = Cursor::new(Vec::new()); + self.write(&mut out)?; + Ok(out.into_inner()) + } +} + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] #[binrw] pub struct ClassAccessFlags(u16); diff --git a/tests/attr_bootstrap_methods.rs b/tests/attr_bootstrap_methods.rs index a419e9f..2f77c99 100644 --- a/tests/attr_bootstrap_methods.rs +++ b/tests/attr_bootstrap_methods.rs @@ -1,16 +1,18 @@ extern crate classfile_parser; -extern crate nom; -use classfile_parser::attribute_info::bootstrap_methods_attribute_parser; -use classfile_parser::class_parser; -use classfile_parser::constant_info::ConstantInfo; +use std::fs::File; + +use binrw::{BinRead, io::BufReader}; +use classfile_parser::attribute_info::AttributeInfoVariant; +use classfile_parser::{ClassFile, constant_info::ConstantInfo}; #[test] fn test_attribute_bootstrap_methods() { - match class_parser(include_bytes!( - "../java-assets/compiled-classes/BootstrapMethods.class" - )) { - Ok((_, c)) => { + let bootstrap_methods_class = + File::open("java-assets/compiled-classes/BootstrapMethods.class").unwrap(); + let res = ClassFile::read(&mut BufReader::new(bootstrap_methods_class)); + match res { + Ok(c) => { println!( "Valid class file, version {},{} const_pool({}), this=const[{}], super=const[{}], interfaces({}), fields({}), methods({}), attributes({}), access({:?})", c.major_version, @@ -48,8 +50,8 @@ fn test_attribute_bootstrap_methods() { for attribute_item in c.attributes.iter() { if attribute_item.attribute_name_index == bootstrap_method_const_index { - match bootstrap_methods_attribute_parser(&attribute_item.info) { - Ok((_, bsma)) => { + match attribute_item.info_parsed.as_ref() { +
Some(AttributeInfoVariant::BootstrapMethods(bsma)) => { assert_eq!(bsma.num_bootstrap_methods, 1); let bsm = &bsma.bootstrap_methods[0]; assert_eq!(bsm.num_bootstrap_arguments, 3); @@ -68,10 +70,11 @@ fn test_attribute_bootstrap_methods() { #[test] fn should_have_no_bootstrap_method_attr_if_no_invoke_dynamic() { - match class_parser(include_bytes!( - "../java-assets/compiled-classes/BasicClass.class" - )) { - Ok((_, c)) => { + let bootstrap_methods_class = + File::open("java-assets/compiled-classes/BasicClass.class").unwrap(); + let res = ClassFile::read(&mut BufReader::new(bootstrap_methods_class)); + match res { + Ok(c) => { for const_item in c.const_pool.iter() { if let ConstantInfo::Utf8(ref c) = *const_item && c.utf8_string.to_string() == "BootstrapMethods" diff --git a/tests/attr_stack_map_table.rs b/tests/attr_stack_map_table.rs index 6fb776d..47da9b4 100644 --- a/tests/attr_stack_map_table.rs +++ b/tests/attr_stack_map_table.rs @@ -1,30 +1,33 @@ extern crate classfile_parser; -extern crate nom; -use classfile_parser::attribute_info::AttributeInfo; -use classfile_parser::class_parser; +use std::fs::File; +use std::io::Cursor; +use std::io::prelude::*; + +use binrw::prelude::*; +use classfile_parser::ClassFile; +use classfile_parser::attribute_info::{AttributeInfoVariant, StackMapFrameInner}; use classfile_parser::constant_info::ConstantInfo; #[test] fn test_attribute_stack_map_table() { - let stack_map_class = include_bytes!("../java-assets/compiled-classes/Factorial.class"); - let res = class_parser(stack_map_class); + let mut contents: Vec = Vec::new(); + let mut stack_map_class = File::open("java-assets/compiled-classes/Factorial.class").unwrap(); + stack_map_class.read_to_end(&mut contents).unwrap(); + let res = ClassFile::read(&mut Cursor::new(&mut contents)); match res { - Ok((_, c)) => { - use classfile_parser::attribute_info::code_attribute_parser; - use classfile_parser::attribute_info::stack_map_table_attribute_parser; - + Ok(c) => { let mut 
stack_map_table_index = 0; println!("Constant pool:"); for (const_index, const_item) in c.const_pool.iter().enumerate() { println!("\t[{}] = {:?}", (const_index + 1), const_item); - if let ConstantInfo::Utf8(ref c) = *const_item - && c.utf8_string.to_string() == "StackMapTable" - { - if stack_map_table_index != 0 { - panic!("Should not find more than one StackMapTable constant"); + if let ConstantInfo::Utf8(ref c) = *const_item { + if c.utf8_string == "StackMapTable" { + if stack_map_table_index != 0 { + panic!("Should not find more than one StackMapTable constant"); + } + stack_map_table_index = (const_index + 1) as u16; } - stack_map_table_index = (const_index + 1) as u16; } } println!("Methods:"); @@ -39,46 +42,35 @@ fn test_attribute_stack_map_table() { assert_eq!(method.attributes.len(), 1); assert_eq!(method.attributes.len(), method.attributes_count as usize); - let code = match code_attribute_parser(&method.attributes[0].info) { - Ok((_, c)) => c, + // The top-level method attribute should be parsed as Code + let code = match method.attributes[0].info_parsed { + Some(AttributeInfoVariant::Code(ref code)) => code, _ => panic!("Could not get code attribute"), }; - let mut stack_map_table_attr_index = 0; - println!("Code Attrs:"); - for (idx, code_attr) in code.attributes.iter().enumerate() { - println!("\t[{}] = {:?}", idx, code_attr); - let AttributeInfo { - ref attribute_name_index, - attribute_length: _, - info: _, - } = *code_attr; - if attribute_name_index == &stack_map_table_index { - stack_map_table_attr_index = idx; - } - } + // Sub-attributes inside CodeAttribute now have interpret_inner called + // automatically, so info_parsed is populated. 
+ let smt_attr = code + .attributes + .iter() + .find(|a| a.attribute_name_index == stack_map_table_index) + .expect("StackMapTable attribute not found"); - let attribute_info_bytes = &code.attributes[stack_map_table_attr_index].info; - let p = stack_map_table_attribute_parser(attribute_info_bytes); - match p { - Ok((data_rem, a)) => { - // We should have used all the data in the stack map attribute - assert!(data_rem.is_empty()); + let smt = match smt_attr.info_parsed { + Some(AttributeInfoVariant::StackMapTable(ref smt)) => smt, + _ => panic!("StackMapTable sub-attribute was not parsed via interpret_inner"), + }; - assert_eq!(a.entries.len(), a.number_of_entries as usize); - assert_eq!(a.entries.len(), 2); + assert_eq!(smt.entries.len(), smt.number_of_entries as usize); + assert_eq!(smt.entries.len(), 2); - use classfile_parser::attribute_info::StackMapFrame::*; - match a.entries[0] { - SameFrame { .. } => {} - _ => panic!("unexpected frame type for frame 0"), - }; - match a.entries[1] { - SameLocals1StackItemFrame { .. } => {} - _ => panic!("unexpected frame type for frame 1: {:?}", &a.entries[1]), - }; - } - _ => panic!("failed to parse StackMapTable"), + match smt.entries[0].inner { + StackMapFrameInner::SameFrame { .. } => {} + _ => panic!("unexpected frame type for frame 0"), + }; + match smt.entries[1].inner { + StackMapFrameInner::SameLocals1StackItemFrame { .. 
} => {} + _ => panic!("unexpected frame type for frame 1: {:?}", &smt.entries[1]), }; } _ => panic!("not a class file"), diff --git a/tests/classfile.rs b/tests/classfile.rs index e63b94f..a60b14a 100644 --- a/tests/classfile.rs +++ b/tests/classfile.rs @@ -1,15 +1,23 @@ extern crate classfile_parser; -extern crate nom; -use classfile_parser::class_parser; +use binrw::BinWrite; +use binrw::prelude::*; +use classfile_parser::ClassFile; +use classfile_parser::attribute_info::AttributeInfoVariant; use classfile_parser::constant_info::ConstantInfo; +use std::fs::File; +use std::io::Cursor; +use std::io::prelude::*; #[test] fn test_valid_class() { - let valid_class = include_bytes!("../java-assets/compiled-classes/BasicClass.class"); - let res = class_parser(valid_class); + let mut contents: Vec = Vec::new(); + let mut valid_class = File::open("java-assets/compiled-classes/BasicClass.class").unwrap(); + valid_class.read_to_end(&mut contents).unwrap(); + let res = ClassFile::read(&mut Cursor::new(&mut contents)); + dbg!(&res); match res { - Result::Ok((_, c)) => { + Result::Ok(c) => { println!( "Valid class file, version {},{} const_pool({}), this=const[{}], super=const[{}], interfaces({}), fields({}), methods({}), attributes({}), access({:?})", c.major_version, @@ -29,10 +37,10 @@ fn test_valid_class() { println!("Constant pool:"); for (const_index, const_item) in c.const_pool.iter().enumerate() { println!("\t[{}] = {:?}", (const_index + 1), const_item); - if let ConstantInfo::Utf8(ref c) = *const_item - && c.utf8_string.to_string() == "Code" - { - code_const_index = (const_index + 1) as u16; + if let ConstantInfo::Utf8(ref c) = *const_item { + if c.utf8_string == "Code" { + code_const_index = (const_index + 1) as u16; + } } } println!("Code index = {}", code_const_index); @@ -69,10 +77,8 @@ fn test_valid_class() { for a in &m.attributes { if a.attribute_name_index == code_const_index { println!("\t\tCode attr found, len = {}", a.attribute_length); - let code_result = - 
classfile_parser::attribute_info::code_attribute_parser(&a.info); - match code_result { - Result::Ok((_, code)) => { + match a.info_parsed.as_ref().unwrap() { + AttributeInfoVariant::Code(code) => { println!("\t\t\tCode! code_length = {}", code.code_length); } _ => panic!("Not a valid code attr?"), @@ -89,78 +95,234 @@ fn test_valid_class() { #[test] fn test_utf_string_constants() { - let valid_class = include_bytes!("../java-assets/compiled-classes/UnicodeStrings.class"); - let res = class_parser(valid_class); + let mut contents: Vec = Vec::new(); + let mut utf8_class = File::open("java-assets/compiled-classes/UnicodeStrings.class").unwrap(); + utf8_class.read_to_end(&mut contents).unwrap(); + let res = ClassFile::read(&mut Cursor::new(contents)); match res { - Result::Ok((_, c)) => { - let mut found_utf_maths_string = false; - let mut found_utf_runes_string = false; - let mut found_utf_braille_string = false; - let mut found_utf_modified_string = false; - let mut found_utf_unpaired_string = false; + Result::Ok(c) => { + if let ConstantInfo::Utf8(ref con) = c.const_pool[13] { + assert_eq!(con.utf8_string, "2H₂ + O₂ ⇌ 2H₂O, R = 4.7 kΩ, ⌀ 200 mm"); + } + + if let ConstantInfo::Utf8(ref con) = c.const_pool[21] { + assert_eq!( + con.utf8_string, + "ᚻᛖ ᚳᚹᚫᚦ ᚦᚫᛏ ᚻᛖ ᛒᚢᛞᛖ ᚩᚾ ᚦᚫᛗ ᛚᚪᚾᛞᛖ ᚾᚩᚱᚦᚹᛖᚪᚱᛞᚢᛗ ᚹᛁᚦ ᚦᚪ ᚹᛖᛥᚫ" + ); + } + + if let ConstantInfo::Utf8(ref con) = c.const_pool[23] { + assert_eq!(con.utf8_string, "⡌⠁⠧⠑ ⠼⠁⠒ ⡍⠜⠇⠑⠹⠰⠎ ⡣⠕⠌"); + } + + if let ConstantInfo::Utf8(ref con) = c.const_pool[25] { + assert_eq!(con.utf8_string, "\0𠜎"); + } + + if let ConstantInfo::Utf8(ref con) = c.const_pool[27] { + assert_eq!(con.utf8_string, "X���X"); + assert_eq!(con.utf8_string.len(), 11); + } + for (const_index, const_item) in c.const_pool.iter().enumerate() { println!("\t[{}] = {:?}", (const_index + 1), const_item); - if let ConstantInfo::Utf8(ref c) = *const_item { - if c.utf8_string.to_string() == "2H₂ + O₂ ⇌ 2H₂O, R = 4.7 kΩ, ⌀ 200 mm" - { - found_utf_maths_string = true; - } - if 
c.utf8_string.to_string() - == "ᚻᛖ ᚳᚹᚫᚦ ᚦᚫᛏ ᚻᛖ ᛒᚢᛞᛖ ᚩᚾ ᚦᚫᛗ ᛚᚪᚾᛞᛖ ᚾᚩᚱᚦᚹᛖᚪᚱᛞᚢᛗ ᚹᛁᚦ ᚦᚪ ᚹᛖᛥᚫ" - { - found_utf_runes_string = true; - } - if c.utf8_string.to_string() == "⡌⠁⠧⠑ ⠼⠁⠒ ⡍⠜⠇⠑⠹⠰⠎ ⡣⠕⠌" - { - found_utf_braille_string = true; - } - if c.utf8_string.to_string() == "\0𠜎" { - found_utf_modified_string = true; - } - if c.utf8_string.to_string() == "X���X" && c.utf8_string.len() == 5 { - found_utf_unpaired_string = true; - } - } } - - assert!( - found_utf_maths_string - & found_utf_runes_string - & found_utf_braille_string - & found_utf_modified_string - & found_utf_unpaired_string, - "Failed to find unicode strings" - ); } + _ => panic!("Not a class file"), } } #[test] fn test_malformed_class() { - let malformed_class = include_bytes!("../java-assets/compiled-classes/malformed.class"); - let res = class_parser(malformed_class); - if let Result::Ok((_, _)) = res { + let mut contents: Vec = Vec::new(); + let mut invalid_class = File::open("java-assets/compiled-classes/malformed.class").unwrap(); + invalid_class.read_to_end(&mut contents).unwrap(); + let res = ClassFile::read(&mut Cursor::new(contents)); + if res.is_ok() { panic!("The file is not valid and shouldn't be parsed") }; } -// #[test] -// fn test_constant_utf8() { -// let hello_world_data = &[ -// // 0x01, // tag = 1 -// 0x00, 0x0C, // length = 12 -// 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21 // 'Hello world!' 
in UTF8 -// ]; -// let res = const_utf8(hello_world_data); - -// match res { -// Result::Ok((_, c)) => -// match c { -// Constant::Utf8(ref s) => -// println!("Valid UTF8 const: {}", s.utf8_string), -// _ => panic!("It's a const, but of what type?") -// }, -// _ => panic!("Not a UTF type const?"), -// }; -// } +#[test] +fn test_round_trip() { + let mut original_bytes: Vec = Vec::new(); + let mut class_file = File::open("java-assets/compiled-classes/BasicClass.class").unwrap(); + class_file.read_to_end(&mut original_bytes).unwrap(); + + let parsed = ClassFile::read(&mut Cursor::new(&original_bytes)).expect("failed to parse class"); + + let mut written_bytes = Cursor::new(Vec::new()); + parsed + .write(&mut written_bytes) + .expect("failed to write class"); + let written_bytes = written_bytes.into_inner(); + + assert_eq!( + original_bytes.len(), + written_bytes.len(), + "written class file has different length: original={}, written={}", + original_bytes.len(), + written_bytes.len() + ); + assert_eq!( + original_bytes, written_bytes, + "written class file bytes differ from original" + ); +} + +/// Verify that sync_from_parsed() on unmodified Code attributes produces identical bytes. +#[test] +fn test_sync_from_parsed_idempotent() { + // Note: UnicodeStrings excluded — it contains invalid UTF-8 (unpaired surrogates) + // that get normalized to U+FFFD during parsing, so round-trip is not byte-identical. 
+ for class_name in &["BasicClass", "Factorial", "HelloWorld", "Instructions"] { + let path = format!("java-assets/compiled-classes/{}.class", class_name); + let mut original_bytes = Vec::new(); + File::open(&path) + .unwrap_or_else(|_| panic!("failed to open {}", path)) + .read_to_end(&mut original_bytes) + .unwrap(); + + let mut class_file = + ClassFile::read(&mut Cursor::new(&original_bytes)).expect("failed to parse"); + + // Sync all Code attributes without modifying them + for method in &mut class_file.methods { + for attr in &mut method.attributes { + if matches!(attr.info_parsed, Some(AttributeInfoVariant::Code(_))) { + attr.sync_from_parsed().expect("sync_from_parsed failed"); + } + } + } + + let mut written_bytes = Cursor::new(Vec::new()); + class_file + .write(&mut written_bytes) + .expect("failed to write"); + let written_bytes = written_bytes.into_inner(); + + assert_eq!( + original_bytes, written_bytes, + "{}: sync_from_parsed on unmodified Code changed the output", + class_name + ); + } +} + +/// Verify that modifying an instruction survives write → re-read. 
+#[test] +fn test_mutation_round_trip_instruction() { + let mut original_bytes = Vec::new(); + File::open("java-assets/compiled-classes/BasicClass.class") + .unwrap() + .read_to_end(&mut original_bytes) + .unwrap(); + + let mut class_file = + ClassFile::read(&mut Cursor::new(&original_bytes)).expect("failed to parse"); + + // Find first method with a Code attribute containing Aload0 + let mut found = false; + for method in &mut class_file.methods { + for attr in &mut method.attributes { + if let Some(AttributeInfoVariant::Code(ref mut code)) = attr.info_parsed { + for instr in &mut code.code { + if matches!(instr, classfile_parser::code_attribute::Instruction::Aload0) { + *instr = classfile_parser::code_attribute::Instruction::Aload1; + found = true; + break; + } + } + if found { + attr.sync_from_parsed().unwrap(); + break; + } + } + } + if found { + break; + } + } + assert!(found, "could not find Aload0 in BasicClass"); + + // Write and re-read + let mut out = Cursor::new(Vec::new()); + class_file.write(&mut out).expect("failed to write"); + let written = out.into_inner(); + + let reparsed = ClassFile::read(&mut Cursor::new(&written)).expect("failed to re-parse"); + + // Verify the modification survived + let mut verified = false; + for method in &reparsed.methods { + for attr in &method.attributes { + if let Some(AttributeInfoVariant::Code(ref code)) = attr.info_parsed { + for instr in &code.code { + if matches!(instr, classfile_parser::code_attribute::Instruction::Aload1) { + verified = true; + break; + } + } + } + } + } + assert!(verified, "Aload1 not found after round-trip"); +} + +/// Verify constant pool modification survives write → re-read. 
+#[test] +fn test_mutation_round_trip_constant_pool() { + let mut original_bytes = Vec::new(); + File::open("java-assets/compiled-classes/BasicClass.class") + .unwrap() + .read_to_end(&mut original_bytes) + .unwrap(); + + let mut class_file = + ClassFile::read(&mut Cursor::new(&original_bytes)).expect("failed to parse"); + + // Modify a UTF-8 string in the constant pool + let mut modified_value = None; + for entry in &mut class_file.const_pool { + if let ConstantInfo::Utf8(utf8) = entry { + if utf8.utf8_string == "Code" { + // Don't modify "Code" — it's used for attribute resolution. + continue; + } + if utf8.utf8_string.len() > 2 { + modified_value = Some(utf8.utf8_string.clone()); + utf8.utf8_string = "MODIFIED".to_string(); + break; + } + } + } + let original_value = modified_value.expect("no suitable UTF-8 constant found"); + + let mut out = Cursor::new(Vec::new()); + class_file.write(&mut out).expect("failed to write"); + let written = out.into_inner(); + + let reparsed = ClassFile::read(&mut Cursor::new(&written)).expect("failed to re-parse"); + + // Verify the modification survived and original value is gone + let mut found_modified = false; + let mut found_original = false; + for entry in &reparsed.const_pool { + if let ConstantInfo::Utf8(utf8) = entry { + if utf8.utf8_string == "MODIFIED" { + found_modified = true; + } + if utf8.utf8_string == original_value { + found_original = true; + } + } + } + assert!(found_modified, "'MODIFIED' not found after round-trip"); + assert!( + !found_original, + "original value '{}' still present after modification", + original_value + ); +} diff --git a/tests/code_attribute.rs b/tests/code_attribute.rs index c492465..24cf134 100644 --- a/tests/code_attribute.rs +++ b/tests/code_attribute.rs @@ -1,50 +1,57 @@ -//only works for nightly builds at the moment -//#![feature(assert_matches)] - extern crate classfile_parser; -//use std::assert_matches::assert_matches; +use std::io::Cursor; +use assert_matches::assert_matches; +use 
binrw::BinRead; +use classfile_parser::ClassFile; use classfile_parser::attribute_info::{ - DefaultAnnotation, ElementValue, InnerClassAccessFlags, TargetInfo, code_attribute_parser, - element_value_parser, enclosing_method_attribute_parser, inner_classes_attribute_parser, - line_number_table_attribute_parser, method_parameters_attribute_parser, - runtime_invisible_annotations_attribute_parser, - runtime_invisible_parameter_annotations_attribute_parser, - runtime_visible_annotations_attribute_parser, - runtime_visible_parameter_annotations_attribute_parser, - runtime_visible_type_annotations_attribute_parser, signature_attribute_parser, - source_debug_extension_parser, + AttributeInfoVariant, ElementValue, InnerClassAccessFlags, LineNumberTableAttribute, TargetInfo, }; -use classfile_parser::class_parser; use classfile_parser::code_attribute::{ - Instruction, LocalVariableTableAttribute, code_parser, instruction_parser, - local_variable_type_table_parser, + Instruction, LocalVariableTableAttribute, LocalVariableTypeTableAttribute, }; -use classfile_parser::constant_info::{ConstantInfo, Utf8Constant}; +use classfile_parser::constant_info::ConstantInfo; use classfile_parser::method_info::MethodAccessFlags; +fn lookup_string(c: &ClassFile, index: u16) -> Option<String> { + match &c.const_pool[(index - 1) as usize] { + classfile_parser::constant_info::ConstantInfo::Utf8(utf8) => { + Some(utf8.utf8_string.to_string()) + } + _ => None, + } +} + #[test] fn test_simple() { - let instruction = &[0x11, 0xff, 0xfe]; + let mut instruction = vec![0x11, 0xff, 0xfe]; assert_eq!( - Ok((&[][..], Instruction::Sipush(-2i16))), - instruction_parser(instruction, 0) + Instruction::Sipush(-2i16), + Instruction::read_be_args( + &mut Cursor::new(&mut instruction), + binrw::args!
{ address: 0 } + ) + .unwrap() ); } #[test] fn test_wide() { - let instruction = &[0xc4, 0x15, 0xaa, 0xbb]; + let mut instruction = vec![0xc4, 0x15, 0xaa, 0xbb]; assert_eq!( - Ok((&[][..], Instruction::IloadWide(0xaabb))), - instruction_parser(instruction, 0) + Instruction::IloadWide(0xaabb), + Instruction::read_be_args( + &mut Cursor::new(&mut instruction), + binrw::args! { address: 0 } + ) + .unwrap() ); } #[test] fn test_alignment() { - let instructions = vec![ + let mut instructions: Vec<(u32, Vec<u8>)> = vec![ ( 3, vec![ @@ -58,69 +65,91 @@ fn test_alignment() { ], ), ]; - let expected = Ok(( - &[][..], - Instruction::Tableswitch { - default: 10, - low: 20, - high: 21, - offsets: vec![30, 31], - }, - )); - for (address, instruction) in instructions { - assert_eq!(expected, instruction_parser(&instruction, address)); + + let expected = Instruction::Tableswitch { + default: 10, + low: 20, + high: 21, + offsets: vec![30, 31], + }; + + for (address, instruction) in &mut instructions { + assert_eq!( + expected, + Instruction::read_be_args( + &mut Cursor::new(instruction), + binrw::args! { address: *address } + ) + .unwrap() + ); } } #[test] fn test_incomplete() { let code = &[0x59, 0x59, 0xc4, 0x15]; // dup, dup, - let expected = Ok(( - &[0xc4, 0x15][..], - vec![(0, Instruction::Dup), (1, Instruction::Dup)], - )); - assert_eq!(expected, code_parser(code)); + let mut c = Cursor::new(code); + + assert_eq!( + Instruction::Dup, + Instruction::read_be_args(&mut c, binrw::args! { address: 0 }).unwrap() + ); + assert_eq!( + Instruction::Dup, + Instruction::read_be_args(&mut c, binrw::args! { address: 0 }).unwrap() + ); + + let next = Instruction::read_be_args(&mut c, binrw::args!
{ address: 0 }); + if let binrw::Error::NoVariantMatch { pos } = next.unwrap_err() { + assert_eq!(pos, 2); + } } #[test] fn test_class() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Instructions.class"); - let (_, class) = class_parser(class_bytes).unwrap(); - let method_info = &class + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); + let method_info = class .methods .iter() .find(|m| m.access_flags.contains(MethodAccessFlags::STATIC)) .unwrap(); - let (_, code_attribute) = code_attribute_parser(&method_info.attributes[0].info).unwrap(); - - let parsed = code_parser(&code_attribute.code); - - assert!(parsed.is_ok()); - assert_eq!(64, parsed.unwrap().1.len()); -} -fn lookup_string(c: &classfile_parser::ClassFile, index: u16) -> Option { - let con = &c.const_pool[(index - 1) as usize]; - match con { - classfile_parser::constant_info::ConstantInfo::Utf8(utf8) => { - Some(utf8.utf8_string.to_string()) + let code_attr = method_info.attributes.iter().find_map(|attr| { + if let Some(AttributeInfoVariant::Code(code)) = &attr.info_parsed { + Some(code) + } else { + None } - _ => None, - } + }); + + let code = code_attr.expect("Should have found a Code attribute"); + assert_eq!(64, code.code.len()); } #[test] fn method_parameters() { let class_bytes = include_bytes!("../java-assets/compiled-classes/BasicClass.class"); - let (_, class) = class_parser(class_bytes).unwrap(); - let method_info = &class.methods.iter().last().unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); + let method_info = class.methods.iter().last().unwrap(); // The class was not compiled with "javac -parameters" this required being able to find // MethodParameters in the class file, for example: // javac -parameters ./java-assets/src/uk/co/palmr/classfileparser/BasicClass.java -d ./java-assets/compiled-classes ; cp ./java-assets/compiled-classes/uk/co/palmr/classfileparser/BasicClass.class 
./java-assets/compiled-classes/BasicClass.class assert_eq!(method_info.attributes.len(), 2); - let (_, method_parameters) = - method_parameters_attribute_parser(&method_info.attributes[1].info).unwrap(); + + let method_parameters = method_info + .attributes + .iter() + .find_map(|attr| { + if let Some(AttributeInfoVariant::MethodParameters(mp)) = &attr.info_parsed { + Some(mp) + } else { + None + } + }) + .expect("Should have found MethodParameters attribute"); + assert_eq!( lookup_string( &class, @@ -140,13 +169,11 @@ fn method_parameters() { #[test] fn inner_classes() { let class_bytes = include_bytes!("../java-assets/compiled-classes/InnerClasses.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); for attr in &class.attributes { - match lookup_string(&class, attr.attribute_name_index) { - Some(x) if x == "InnerClasses" => { - let (_, inner_class_attrs) = inner_classes_attribute_parser(&attr.info).unwrap(); - + match &attr.info_parsed { + Some(AttributeInfoVariant::InnerClasses(inner_class_attrs)) => { assert_eq!(inner_class_attrs.number_of_classes, 4); assert_eq!( @@ -154,7 +181,7 @@ fn inner_classes() { inner_class_attrs.classes.len() as u16 ); - for c in inner_class_attrs.classes { + for c in &inner_class_attrs.classes { dbg!(&class.const_pool[(c.inner_class_info_index - 1) as usize]); // only == 0 when this class is a top-level class or interface, or when it's @@ -179,7 +206,7 @@ fn inner_classes() { } Some(_) => {} None => panic!( - "Could not find attribute name for index {}", + "Could not parse attribute for index {}", attr.attribute_name_index ), } @@ -190,33 +217,19 @@ fn inner_classes() { // test for enclosing method attribute, which only applies to local and anonymous classes fn enclosing_method() { let class_bytes = include_bytes!("../java-assets/compiled-classes/InnerClasses$2.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = 
ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); for attr in &class.attributes { - match lookup_string(&class, attr.attribute_name_index) { - Some(x) if x == "EnclosingMethod" => { + match &attr.info_parsed { + Some(AttributeInfoVariant::EnclosingMethod(enclosing)) => { assert_eq!(attr.attribute_length, 4); - let (_, inner_class_attrs) = enclosing_method_attribute_parser(&attr.info).unwrap(); - - match &class.const_pool[(inner_class_attrs.class_index - 1) as usize] { + match &class.const_pool[(enclosing.class_index - 1) as usize] { classfile_parser::constant_info::ConstantInfo::Class(class_constant) => { - /* nightly only rn - * use regular asserts + decomposition instead - let _expected = String::from("InnerClasses"); - assert_matches!( - &class.const_pool[(class_constant.name_index - 1) as usize], - ConstantInfo::Utf8(classfile_parser::constant_info::Utf8Constant { - utf8_string: _expected, - }) - ); */ if let ConstantInfo::Utf8(inner_str) = &class.const_pool[(class_constant.name_index - 1) as usize] { - assert_eq!( - inner_str.utf8_string, - binrw::NullWideString::from("InnerClasses") - ); + assert_eq!(inner_str.utf8_string, String::from("InnerClasses")); } dbg!(&class.const_pool[(class_constant.name_index - 1) as usize]); @@ -224,43 +237,14 @@ fn enclosing_method() { _ => panic!("Expected Class constant"), } - match &class.const_pool[(inner_class_attrs.method_index - 1) as usize] { + match &class.const_pool[(enclosing.method_index - 1) as usize] { classfile_parser::constant_info::ConstantInfo::NameAndType( name_and_type_constant, ) => { - /* - let mut _expected = String::from("sayHello"); - assert_matches!( - &class.const_pool[(name_and_type_constant.name_index - 1) as usize], - ConstantInfo::Utf8(classfile_parser::constant_info::Utf8Constant { - utf8_string: _expected, - }) - ); - */ - if let ConstantInfo::Utf8(inner_str) = - &class.const_pool[(name_and_type_constant.name_index - 1) as usize] - { - assert_eq!( - inner_str.utf8_string, - 
binrw::NullWideString::from("sayHello") - ); - } - dbg!(&class.const_pool[(name_and_type_constant.name_index - 1) as usize]); - - /* - _expected = String::from("()V"); - assert_matches!( - &class.const_pool - [(name_and_type_constant.descriptor_index - 1) as usize], - ConstantInfo::Utf8(classfile_parser::constant_info::Utf8Constant { - utf8_string: _expected, - }) - ); - */ if let ConstantInfo::Utf8(inner_str) = &class.const_pool [(name_and_type_constant.descriptor_index - 1) as usize] { - assert_eq!(inner_str.utf8_string, binrw::NullWideString::from("()V")); + assert_eq!(inner_str.utf8_string, String::from("()V")); } dbg!( &class.const_pool @@ -269,13 +253,10 @@ fn enclosing_method() { } _ => panic!("Expected NameAndType constant"), } - - //uncomment to see dbg output from above - //assert!(false); } Some(_) => {} None => panic!( - "Could not find attribute name for index {}", + "Could not parse attribute for index {}", attr.attribute_name_index ), } @@ -285,20 +266,16 @@ fn enclosing_method() { #[test] fn synthetic_attribute() { let class_bytes = include_bytes!("../java-assets/compiled-classes/InnerClasses$2.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let synthetic_attrs = class .attributes .iter() - .filter( - |attribute_info| match lookup_string(&class, attribute_info.attribute_name_index) { - Some(s) if s == "Synethic" => true, - Some(_) => false, - None => panic!( - "Could not find attribute name for index {}", - attribute_info.attribute_name_index - ), - }, - ) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::Synthetic(_)) + ) + }) .collect::>(); for attr in &synthetic_attrs { @@ -310,30 +287,26 @@ fn synthetic_attribute() { #[test] fn signature_attribute() { let class_bytes = include_bytes!("../java-assets/compiled-classes/BootstrapMethods.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + 
let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let signature_attrs = class .methods .iter() .flat_map(|method_info| &method_info.attributes) - .filter( - |attribute_info| match lookup_string(&class, attribute_info.attribute_name_index) { - Some(s) if s == "Signature" => { - eprintln!("Got a signature attr!"); - true - } - Some(_) => false, - None => panic!( - "Could not find attribute name for index {}", - attribute_info.attribute_name_index - ), - }, - ) + .filter(|attribute_info| { + if let Some(AttributeInfoVariant::Signature(_)) = &attribute_info.info_parsed { + eprintln!("Got a signature attr!"); + true + } else { + false + } + }) .collect::>(); for attr in &signature_attrs { - let (_, signature_attr) = signature_attribute_parser(&attr.info).unwrap(); - let signature_string = lookup_string(&class, signature_attr.signature_index).unwrap(); - dbg!(signature_string); + if let Some(AttributeInfoVariant::Signature(sig)) = &attr.info_parsed { + let signature_string = lookup_string(&class, sig.signature_index).unwrap(); + dbg!(signature_string); + } } //uncomment to see dbg output from above @@ -344,40 +317,34 @@ fn signature_attribute() { fn local_variable_table() { // The class was not compiled with "javac -g" let class_bytes = include_bytes!("../java-assets/compiled-classes/LocalVariableTable.class"); - let (_, class) = class_parser(class_bytes).unwrap(); - let method_info = &class.methods.iter().last().unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); + let method_info = class.methods.iter().last().unwrap(); let code_attribute = method_info .attributes .iter() .find_map(|attribute_info| { - match lookup_string(&class, attribute_info.attribute_name_index)?.as_str() { - "Code" => { - classfile_parser::attribute_info::code_attribute_parser(&attribute_info.info) - .ok() - } - _ => None, + if let Some(AttributeInfoVariant::Code(code)) = &attribute_info.info_parsed { + Some(code) + } else { + 
None } }) - .map(|i| i.1) - .unwrap(); + .expect("Should have found a Code attribute"); + // Code attribute's sub-attributes do NOT have info_parsed populated, so we parse manually let local_variable_table_attribute: LocalVariableTableAttribute = code_attribute .attributes .iter() .find_map(|attribute_info| { match lookup_string(&class, attribute_info.attribute_name_index)?.as_str() { "LocalVariableTable" => { - classfile_parser::code_attribute::local_variable_table_parser( - &attribute_info.info, - ) - .ok() + LocalVariableTableAttribute::read(&mut Cursor::new(&attribute_info.info)).ok() } _ => None, } }) - .map(|a| a.1) - .unwrap(); + .expect("Should have found a LocalVariableTable attribute"); let types: Vec = local_variable_table_attribute .items @@ -399,52 +366,41 @@ fn local_variable_table() { #[test] fn runtime_visible_annotations() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let runtime_visible_annotations_attribute = class .methods .iter() .flat_map(|m| &m.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "RuntimeVisibleAnnotations")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::RuntimeVisibleAnnotations(_)) + ) + }) .collect::>(); assert_eq!(runtime_visible_annotations_attribute.len(), 1); let f = runtime_visible_annotations_attribute.first().unwrap(); - let visible_annotations = runtime_visible_annotations_attribute_parser(&f.info); - let inner = &visible_annotations.unwrap(); - assert!(&inner.0.is_empty()); - - /* - let should_be = RuntimeVisibleTypeAnnotationsAttribute { - num_annotations: 1, - annotations: vec![RuntimeAnnotation { - type_index: 30, - num_element_value_pairs: 1, - element_value_pairs: vec![ElementValuePair { - element_name_index: 
31, - value: ElementValue::ConstValueIndex { - tag: 's', - value: 32, - }, - }], - }], + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::RuntimeVisibleAnnotations(rva)) => rva, + _ => panic!("Expected RuntimeVisibleAnnotations"), }; - */ - assert_eq!(inner.1.num_annotations, 1); - assert_eq!(inner.1.annotations.len(), 1); - assert_eq!(inner.1.annotations[0].type_index, 46); - assert_eq!(inner.1.annotations[0].num_element_value_pairs, 1); - assert_eq!(inner.1.annotations[0].element_value_pairs.len(), 1); + assert_eq!(inner.num_annotations, 1); + assert_eq!(inner.annotations.len(), 1); + assert_eq!(inner.annotations[0].type_index, 46); + assert_eq!(inner.annotations[0].num_element_value_pairs, 1); + assert_eq!(inner.annotations[0].element_value_pairs.len(), 1); assert_eq!( - inner.1.annotations[0].element_value_pairs[0].element_name_index, + inner.annotations[0].element_value_pairs[0].element_name_index, 37 ); - match inner.1.annotations[0].element_value_pairs[0].value { - ElementValue::ConstValueIndex { tag, value } => { - assert_eq!(tag, 's'); - assert_eq!(value, 47); + match &inner.annotations[0].element_value_pairs[0].value { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 47); } _ => panic!("Expected ConstValueIndex"), } @@ -453,52 +409,41 @@ fn runtime_visible_annotations() { #[test] fn runtime_invisible_annotations() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let runtime_invisible_annotations_attribute = class .methods .iter() .flat_map(|m| &m.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "RuntimeInvisibleAnnotations")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + 
Some(AttributeInfoVariant::RuntimeInvisibleAnnotations(_)) + ) + }) .collect::>(); assert_eq!(runtime_invisible_annotations_attribute.len(), 1); let f = runtime_invisible_annotations_attribute.first().unwrap(); - let invisible_annotations = runtime_invisible_annotations_attribute_parser(&f.info); - let inner = &invisible_annotations.unwrap(); - assert!(&inner.0.is_empty()); - - /* - let should_be = RuntimeVisibleTypeAnnotationsAttribute { - num_annotations: 1, - annotations: vec![RuntimeAnnotation { - type_index: 30, - num_element_value_pairs: 1, - element_value_pairs: vec![ElementValuePair { - element_name_index: 31, - value: ElementValue::ConstValueIndex { - tag: 's', - value: 32, - }, - }], - }], + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::RuntimeInvisibleAnnotations(ria)) => ria, + _ => panic!("Expected RuntimeInvisibleAnnotations"), }; - */ - assert_eq!(inner.1.num_annotations, 1); - assert_eq!(inner.1.annotations.len(), 1); - assert_eq!(inner.1.annotations[0].type_index, 49); - assert_eq!(inner.1.annotations[0].num_element_value_pairs, 1); - assert_eq!(inner.1.annotations[0].element_value_pairs.len(), 1); + assert_eq!(inner.num_annotations, 1); + assert_eq!(inner.annotations.len(), 1); + assert_eq!(inner.annotations[0].type_index, 49); + assert_eq!(inner.annotations[0].num_element_value_pairs, 1); + assert_eq!(inner.annotations[0].element_value_pairs.len(), 1); assert_eq!( - inner.1.annotations[0].element_value_pairs[0].element_name_index, + inner.annotations[0].element_value_pairs[0].element_name_index, 37 ); - match inner.1.annotations[0].element_value_pairs[0].value { - ElementValue::ConstValueIndex { tag, value } => { - assert_eq!(tag, 's'); - assert_eq!(value, 50); + match &inner.annotations[0].element_value_pairs[0].value { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 50); } _ => panic!("Expected ConstValueIndex"), } @@ -507,34 +452,40 @@ fn runtime_invisible_annotations() { #[test] fn 
runtime_visible_parameter_annotations() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let runtime_visible_annotations_attribute = class .methods .iter() .flat_map(|m| &m.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "RuntimeVisibleParameterAnnotations")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::RuntimeVisibleParameterAnnotations(_)) + ) + }) .collect::>(); assert_eq!(runtime_visible_annotations_attribute.len(), 1); let f = runtime_visible_annotations_attribute.first().unwrap(); - let visible_annotations = runtime_visible_parameter_annotations_attribute_parser(&f.info); - let inner = &visible_annotations.unwrap(); - assert!(&inner.0.is_empty()); + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::RuntimeVisibleParameterAnnotations(rvpa)) => rvpa, + _ => panic!("Expected RuntimeVisibleParameterAnnotations"), + }; - assert_eq!(inner.1.num_parameters, 2); - assert_eq!(inner.1.parameter_annotations.len(), 2); - assert_eq!(inner.1.parameter_annotations[0].num_annotations, 1); - assert_eq!(inner.1.parameter_annotations[0].annotations.len(), 1); + assert_eq!(inner.num_parameters, 2); + assert_eq!(inner.parameter_annotations.len(), 2); + assert_eq!(inner.parameter_annotations[0].num_annotations, 1); + assert_eq!(inner.parameter_annotations[0].annotations.len(), 1); - match inner.1.parameter_annotations[0].annotations[0].element_value_pairs[0].value { - ElementValue::ConstValueIndex { tag, value } => { - assert_eq!(tag, 's'); - assert_eq!(value, 53); + match &inner.parameter_annotations[0].annotations[0].element_value_pairs[0].value { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 53); } _ => panic!( 
"expected ConstValueIndex, got {:?}", - inner.1.parameter_annotations[0].annotations[0].element_value_pairs[0].value + inner.parameter_annotations[0].annotations[0].element_value_pairs[0].value ), } } @@ -542,34 +493,42 @@ fn runtime_visible_parameter_annotations() { #[test] fn runtime_invisible_parameter_annotations() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let runtime_invisible_annotations_attribute = class .methods .iter() .flat_map(|m| &m.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "RuntimeInvisibleParameterAnnotations")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::RuntimeInvisibleParameterAnnotations( + _ + )) + ) + }) .collect::>(); assert_eq!(runtime_invisible_annotations_attribute.len(), 1); let f = runtime_invisible_annotations_attribute.first().unwrap(); - let invisible_annotations = runtime_invisible_parameter_annotations_attribute_parser(&f.info); - let inner = &invisible_annotations.unwrap(); - assert!(&inner.0.is_empty()); + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::RuntimeInvisibleParameterAnnotations(ripa)) => ripa, + _ => panic!("Expected RuntimeInvisibleParameterAnnotations"), + }; - assert_eq!(inner.1.num_parameters, 2); - assert_eq!(inner.1.parameter_annotations.len(), 2); - assert_eq!(inner.1.parameter_annotations[1].num_annotations, 1); - assert_eq!(inner.1.parameter_annotations[1].annotations.len(), 1); + assert_eq!(inner.num_parameters, 2); + assert_eq!(inner.parameter_annotations.len(), 2); + assert_eq!(inner.parameter_annotations[1].num_annotations, 1); + assert_eq!(inner.parameter_annotations[1].annotations.len(), 1); - match 
inner.1.parameter_annotations[1].annotations[0].element_value_pairs[0].value { - ElementValue::ConstValueIndex { tag, value } => { - assert_eq!(tag, 's'); - assert_eq!(value, 50); + match &inner.parameter_annotations[1].annotations[0].element_value_pairs[0].value { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 50); } _ => panic!( "expected ConstValueIndex, got {:?}", - inner.1.parameter_annotations[0].annotations[0].element_value_pairs[0].value + inner.parameter_annotations[1].annotations[0].element_value_pairs[0].value ), } } @@ -577,112 +536,126 @@ fn runtime_invisible_parameter_annotations() { #[test] fn runtime_visible_type_annotations() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let runtime_visible_type_annotations_attribute = class .fields .iter() .flat_map(|f| &f.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "RuntimeVisibleTypeAnnotations")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::RuntimeVisibleTypeAnnotations(_)) + ) + }) .collect::<Vec<_>>(); assert_eq!(runtime_visible_type_annotations_attribute.len(), 1); let f = runtime_visible_type_annotations_attribute.first().unwrap(); - let visible_annotations = runtime_visible_type_annotations_attribute_parser(&f.info); - let inner = &visible_annotations.unwrap(); - assert_eq!(inner.1.num_annotations, 1); - assert_eq!(inner.1.type_annotations.len(), 1); - assert_eq!(inner.1.type_annotations[0].target_type, 19); - //assert_matches!(inner.1.type_annotations[0].target_info, TargetInfo::Empty); - assert_eq!(inner.1.type_annotations[0].target_path.path_length, 0); - assert_eq!(inner.1.type_annotations[0].target_path.paths.len(), 0); - 
assert_eq!(inner.1.type_annotations[0].type_index, 36); - assert_eq!(inner.1.type_annotations[0].num_element_value_pairs, 1); - assert_eq!(inner.1.type_annotations[0].element_value_pairs.len(), 1); + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::RuntimeVisibleTypeAnnotations(rvta)) => rvta, + _ => panic!("Expected RuntimeVisibleTypeAnnotations"), + }; + + assert_eq!(inner.num_annotations, 1); + assert_eq!(inner.type_annotations.len(), 1); + assert_eq!(inner.type_annotations[0].target_type, 19); + assert_matches!(inner.type_annotations[0].target_info, TargetInfo::Empty); + assert_eq!(inner.type_annotations[0].target_path.path_length, 0); + assert_eq!(inner.type_annotations[0].target_path.paths.len(), 0); + assert_eq!(inner.type_annotations[0].type_index, 36); + assert_eq!(inner.type_annotations[0].num_element_value_pairs, 1); + assert_eq!(inner.type_annotations[0].element_value_pairs.len(), 1); assert_eq!( - inner.1.type_annotations[0].element_value_pairs[0].element_name_index, + inner.type_annotations[0].element_value_pairs[0].element_name_index, 37 ); - /* - assert_matches!( - inner.1.type_annotations[0].element_value_pairs[0].value, - ElementValue::ConstValueIndex { - tag: 's', - value: 38 + match &inner.type_annotations[0].element_value_pairs[0].value { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 38); } - ); - */ + _ => panic!("Expected ConstValueIndex"), + } } #[test] fn runtime_invisible_type_annotations() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let runtime_invisible_type_annotations_attribute = class .fields .iter() .flat_map(|f| &f.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "RuntimeInvisibleTypeAnnotations")) + 
.filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::RuntimeInvisibleTypeAnnotations(_)) + ) + }) .collect::>(); assert_eq!(runtime_invisible_type_annotations_attribute.len(), 1); let f = runtime_invisible_type_annotations_attribute .first() .unwrap(); - let invisible_annotations = runtime_visible_type_annotations_attribute_parser(&f.info); - let inner = &invisible_annotations.unwrap(); - assert_eq!(inner.1.num_annotations, 1); - assert_eq!(inner.1.type_annotations.len(), 1); - assert_eq!(inner.1.type_annotations[0].target_type, 19); - //assert_matches!(inner.1.type_annotations[0].target_info, TargetInfo::Empty); - assert_eq!(inner.1.type_annotations[0].target_path.path_length, 0); - assert_eq!(inner.1.type_annotations[0].target_path.paths.len(), 0); - assert_eq!(inner.1.type_annotations[0].type_index, 41); - assert_eq!(inner.1.type_annotations[0].num_element_value_pairs, 1); - assert_eq!(inner.1.type_annotations[0].element_value_pairs.len(), 1); + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::RuntimeInvisibleTypeAnnotations(rita)) => rita, + _ => panic!("Expected RuntimeInvisibleTypeAnnotations"), + }; + + assert_eq!(inner.num_annotations, 1); + assert_eq!(inner.type_annotations.len(), 1); + assert_eq!(inner.type_annotations[0].target_type, 19); + assert_matches!(inner.type_annotations[0].target_info, TargetInfo::Empty); + assert_eq!(inner.type_annotations[0].target_path.path_length, 0); + assert_eq!(inner.type_annotations[0].target_path.paths.len(), 0); + assert_eq!(inner.type_annotations[0].type_index, 41); + assert_eq!(inner.type_annotations[0].num_element_value_pairs, 1); + assert_eq!(inner.type_annotations[0].element_value_pairs.len(), 1); assert_eq!( - inner.1.type_annotations[0].element_value_pairs[0].element_name_index, + inner.type_annotations[0].element_value_pairs[0].element_name_index, 37 ); - /* - assert_matches!( - inner.1.type_annotations[0].element_value_pairs[0].value, - 
ElementValue::ConstValueIndex { - tag: 's', - value: 42 + match &inner.type_annotations[0].element_value_pairs[0].value { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 42); } - ); - */ + _ => panic!("Expected ConstValueIndex"), + } } #[test] fn default_annotation_value() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Annotations$VisibleAtRuntime.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let default_annotation_attributes = class .methods .iter() .flat_map(|m| &m.attributes) - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "AnnotationDefault")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::AnnotationDefault(_)) + ) + }) .collect::>(); assert_eq!(default_annotation_attributes.len(), 1); let f = default_annotation_attributes.first().unwrap(); - let default_annotation = element_value_parser(&f.info); - let inner: DefaultAnnotation = default_annotation.unwrap().1 as DefaultAnnotation; - /* - assert_matches!( - inner, - ElementValue::ConstValueIndex { - tag: 's', - value: 10 + let inner = match &f.info_parsed { + Some(AttributeInfoVariant::AnnotationDefault(ad)) => ad, + _ => panic!("Expected AnnotationDefault"), + }; + + match inner { + ElementValue::ConstValueIndex(cv) => { + assert_eq!(cv.tag, 's'); + assert_eq!(cv.value, 10); } - ); - */ - if let ElementValue::ConstValueIndex { tag, value } = inner { - assert_eq!(tag, 's'); - assert_eq!(value, 10); + _ => panic!("Expected ConstValueIndex"), } } @@ -692,44 +665,39 @@ fn default_annotation_value() { // Virtual Machine", so I will leave this test to be better developed when example // use cases are found. 
// #[test] +#[allow(dead_code)] fn source_debug_extension() { let class_bytes = include_bytes!("../java-assets/compiled-classes/BasicClass.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let source_debug_extension_attribute = class .attributes .iter() - .filter(|attribute_info| matches!(lookup_string(&class, attribute_info.attribute_name_index), Some(s) if s == "SourceDebugExtension")) + .filter(|attribute_info| { + matches!( + &attribute_info.info_parsed, + Some(AttributeInfoVariant::SourceDebugExtension(_)) + ) + }) .collect::>(); assert_eq!(source_debug_extension_attribute.len(), 1); - let f = source_debug_extension_attribute.first().unwrap(); - - let default_annotation = source_debug_extension_parser(&f.info); - let inner = &default_annotation.unwrap(); - dbg!(inner); } #[test] fn source_file() { let class_bytes = include_bytes!("../java-assets/compiled-classes/BasicClass.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let source = class .attributes .iter() .find_map(|attribute_info| { - match lookup_string(&class, attribute_info.attribute_name_index)?.as_str() { - "SourceFile" => classfile_parser::attribute_info::sourcefile_attribute_parser( - &attribute_info.info, - ) - .ok(), - o => { - dbg!(o); - None - } + if let Some(AttributeInfoVariant::SourceFile(sf)) = &attribute_info.info_parsed { + Some(sf) + } else { + None } }) - .map(|i| i.1) - .unwrap(); + .expect("Should have found a SourceFile attribute"); let s = lookup_string(&class, source.sourcefile_index).unwrap(); @@ -739,26 +707,36 @@ fn source_file() { #[test] fn line_number_table() { let class_bytes = include_bytes!("../java-assets/compiled-classes/Instructions.class"); - let (_, class) = class_parser(class_bytes).unwrap(); - let default_annotation_attributes = class + let class = ClassFile::read(&mut 
Cursor::new(class_bytes.as_slice())).unwrap(); + let static_method = class .methods .iter() .find(|m| m.access_flags.contains(MethodAccessFlags::STATIC)) .unwrap(); - let (_, code_attribute) = - code_attribute_parser(&default_annotation_attributes.attributes[0].info).unwrap(); + let code_attribute = static_method + .attributes + .iter() + .find_map(|attr| { + if let Some(AttributeInfoVariant::Code(code)) = &attr.info_parsed { + Some(code) + } else { + None + } + }) + .expect("Should have found a Code attribute"); assert_eq!( code_attribute.attributes.len(), code_attribute.attributes_count as usize ); + // Code attribute's sub-attributes do NOT have info_parsed populated, so we parse manually let line_number_tables = &code_attribute .attributes .iter() .filter(|a| lookup_string(&class, a.attribute_name_index).unwrap() == "LineNumberTable") - .map(|a| line_number_table_attribute_parser(&a.info).unwrap().1) + .map(|a| LineNumberTableAttribute::read(&mut Cursor::new(&a.info)).unwrap()) .collect::>(); assert_eq!(line_number_tables.len(), 1); @@ -772,32 +750,35 @@ fn line_number_table() { fn local_variable_type_table() { // The class was not compiled with "javac -g" let class_bytes = include_bytes!("../java-assets/compiled-classes/LocalVariableTable.class"); - let (_, class) = class_parser(class_bytes).unwrap(); - let method_info = &class.methods.iter().last().unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); + let method_info = class.methods.iter().last().unwrap(); - let local_variable_table_type_attribute = method_info + let code_attribute = method_info .attributes .iter() .find_map(|attribute_info| { - match lookup_string(&class, attribute_info.attribute_name_index)?.as_str() { - "Code" => code_attribute_parser(&attribute_info.info).ok(), - _ => None, + if let Some(AttributeInfoVariant::Code(code)) = &attribute_info.info_parsed { + Some(code) + } else { + None } }) - .map(|i| i.1) - .unwrap() + .expect("Should have found a 
Code attribute"); + + // Code attribute's sub-attributes do NOT have info_parsed populated, so we parse manually + let local_variable_table_type_attribute = code_attribute .attributes .iter() .find_map(|attribute_info| { match lookup_string(&class, attribute_info.attribute_name_index)?.as_str() { "LocalVariableTypeTable" => { - local_variable_type_table_parser(&attribute_info.info).ok() + LocalVariableTypeTableAttribute::read(&mut Cursor::new(&attribute_info.info)) + .ok() } _ => None, } }) - .map(|a| a.1) - .unwrap(); + .expect("Should have found a LocalVariableTypeTable attribute"); let types: Vec = local_variable_table_type_attribute .local_variable_type_table @@ -815,17 +796,15 @@ fn local_variable_type_table() { #[test] fn deprecated() { let class_bytes = include_bytes!("../java-assets/compiled-classes/DeprecatedAnnotation.class"); - let (_, class) = class_parser(class_bytes).unwrap(); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())).unwrap(); let deprecated_class_attribute = &class .attributes .iter() .filter(|attribute_info| { matches!( - lookup_string(&class, attribute_info.attribute_name_index) - .unwrap() - .as_str(), - "Deprecated" + &attribute_info.info_parsed, + Some(AttributeInfoVariant::Deprecated(_)) ) }) .collect::>(); @@ -838,10 +817,8 @@ fn deprecated() { .flat_map(|m| &m.attributes) .filter(|attribute_info| { matches!( - lookup_string(&class, attribute_info.attribute_name_index) - .unwrap() - .as_str(), - "Deprecated" + &attribute_info.info_parsed, + Some(AttributeInfoVariant::Deprecated(_)) ) }) .collect::>(); @@ -854,10 +831,8 @@ fn deprecated() { .flat_map(|f| &f.attributes) .filter(|attribute_info| { matches!( - lookup_string(&class, attribute_info.attribute_name_index) - .unwrap() - .as_str(), - "Deprecated" + &attribute_info.info_parsed, + Some(AttributeInfoVariant::Deprecated(_)) ) }) .collect::>(); diff --git a/tests/compiler/e2e.rs b/tests/compiler/e2e.rs new file mode 100644 index 0000000..36a1803 --- /dev/null 
+++ b/tests/compiler/e2e.rs @@ -0,0 +1,1568 @@ +use super::*; + +// --- Codegen unit tests --- + +#[test] +fn test_codegen_return_42() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_, _, mut class_file) = compile_and_load( + "codegen_ret42", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let stmts = parse_method_body("{ return; }").unwrap(); + let generated = + generate_bytecode(&stmts, &mut class_file, true, "([Ljava/lang/String;)V").unwrap(); + + // Should contain Return instruction + assert!( + generated + .instructions + .iter() + .any(|i| matches!(i, Instruction::Return)) + ); + assert!(generated.max_stack >= 1); + assert!(generated.max_locals >= 1); +} + +// --- Basic E2E tests --- + +#[test] +fn test_compile_e2e_hello_compiled() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = + compile_and_load("e2e_hello", "java-assets/src/HelloWorld.java", "HelloWorld"); + + compile_method_body( + r#"{ System.out.println("Compiled!"); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "Compiled!", + "expected 'Compiled!' 
but got: {}", + output + ); +} + +#[test] +fn test_compile_e2e_return_value() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_retval", + "java-assets/src/SimpleMath.java", + "SimpleMath", + ); + + // Replace the body of intMath so that it prints a constant (99) + compile_method_body( + r#"{ System.out.println(99); }"#, + &mut class_file, + "intMath", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "SimpleMath"); + assert!(output.contains("99"), "expected 99 in output: {}", output); +} + +#[test] +fn test_compile_e2e_if_else() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_ifelse", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 10; + if (x > 5) { + System.out.println("big"); + } else { + System.out.println("small"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "big", "expected 'big' but got: {}", output); +} + +#[test] +fn test_compile_e2e_while_loop() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = + compile_and_load("e2e_while", "java-assets/src/HelloWorld.java", "HelloWorld"); + + compile_method_body( + r#"{ + int sum = 0; + int i = 1; + while (i <= 10) { + sum = sum + i; + i = i + 1; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "55", "expected '55' but got: {}", output); +} + +#[test] +fn 
test_compile_e2e_for_loop() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = + compile_and_load("e2e_for", "java-assets/src/HelloWorld.java", "HelloWorld"); + + compile_method_body( + r#"{ + int sum = 0; + for (int i = 1; i <= 5; i = i + 1) { + sum = sum + i; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "15", "expected '15' but got: {}", output); +} + +#[test] +fn test_compile_e2e_arithmetic() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = + compile_and_load("e2e_arith", "java-assets/src/HelloWorld.java", "HelloWorld"); + + compile_method_body( + r#"{ + int a = 10; + int b = 3; + int c = a * b + a / b - a % b; + System.out.println(c); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // 10*3 + 10/3 - 10%3 = 30 + 3 - 1 = 32 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "32", "expected '32' but got: {}", output); +} + +#[test] +fn test_compile_e2e_nested_if() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_nested_if", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 15; + if (x > 10) { + if (x > 20) { + System.out.println("very big"); + } else { + System.out.println("medium"); + } + } else { + System.out.println("small"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "medium", "expected 'medium' but got: {}", 
output); +} + +// --- Switch E2E tests --- + +#[test] +fn test_compile_e2e_switch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_switch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 2; + switch (x) { + case 1: + System.out.println("one"); + break; + case 2: + System.out.println("two"); + break; + case 3: + System.out.println("three"); + break; + default: + System.out.println("other"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "two", "expected 'two' but got: {}", output); +} + +#[test] +fn test_compile_e2e_switch_default() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_switch_default", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 99; + switch (x) { + case 1: + System.out.println("one"); + break; + case 2: + System.out.println("two"); + break; + default: + System.out.println("default"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "default", "expected 'default' but got: {}", output); +} + +// --- Try-catch E2E tests --- + +#[test] +fn test_compile_e2e_try_catch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_try_catch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + try { + throw new RuntimeException("boom"); + } catch (RuntimeException e) { + System.out.println("caught"); + } + 
}"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "caught", "expected 'caught' but got: {}", output); +} + +#[test] +fn test_compile_e2e_try_finally() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_try_finally", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + try { + System.out.println("try"); + } finally { + System.out.println("finally"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "try\nfinally", + "expected 'try\\nfinally' but got: {}", + output + ); +} + +// --- StackMapTable generation tests --- + +#[test] +fn test_compile_e2e_stackmap_if_else() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_smt_ifelse", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 10; + if (x > 5) { + System.out.println("big"); + } else { + System.out.println("small"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "big", "expected 'big' but got: {}", output); +} + +#[test] +fn test_compile_e2e_stackmap_while() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_smt_while", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int sum = 0; + int i = 1; + while (i <= 10) { + sum = sum + i; + i 
= i + 1; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "55", "expected '55' but got: {}", output); +} + +#[test] +fn test_compile_e2e_stackmap_try_catch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_smt_trycatch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + try { + throw new RuntimeException("boom"); + } catch (RuntimeException e) { + System.out.println("caught"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "caught", "expected 'caught' but got: {}", output); +} + +// --- Typed arithmetic tests --- + +#[test] +fn test_compile_e2e_long_arithmetic() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_long_arith", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long a = 1000000000L; + long b = 2000000000L; + long c = a + b; + System.out.println(c); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "3000000000"); +} + +#[test] +fn test_compile_e2e_double_arithmetic() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_double_arith", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + double a = 1.5; + double b = 2.5; + System.out.println(a + b); + 
}"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "4.0"); +} + +#[test] +fn test_compile_e2e_float_arithmetic() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_float_arith", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + float a = 1.5f; + float b = 2.5f; + System.out.println(a * b); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "3.75"); +} + +#[test] +fn test_compile_e2e_widening() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_widening", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 10; + long b = 20L; + long c = a + b; + System.out.println(c); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "30"); +} + +#[test] +fn test_compile_e2e_long_comparison() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_long_cmp", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long x = 5L; + if (x > 3L) { + System.out.println("yes"); + } else { + System.out.println("no"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "yes"); +} + +#[test] 
+fn test_compile_e2e_cast_types() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_cast_types", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + double d = 3.14; + int i = (int) d; + System.out.println(i); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "3"); +} + +#[test] +fn test_compile_e2e_unary_neg_long() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_neg_long", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long x = 10L; + System.out.println(-x); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "-10"); +} + +// --- String concatenation tests --- + +#[test] +fn test_compile_e2e_string_concat() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_str_concat", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + System.out.println("hello" + " " + "world"); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "hello world"); +} + +#[test] +fn test_compile_e2e_string_concat_int() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_str_concat_int", + "java-assets/src/HelloWorld.java", + 
"HelloWorld", + ); + + compile_method_body( + r#"{ + int n = 42; + String s = "n=" + n; + System.out.println(s); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "n=42"); +} + +#[test] +fn test_compile_e2e_string_concat_chain() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_str_concat_chain", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + String s = "a" + "b" + "c" + "d"; + System.out.println(s); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "abcd"); +} + +// --- Typed array tests --- + +#[test] +fn test_compile_e2e_typed_array_long() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_arr_long", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long[] arr = new long[2]; + arr[0] = 100L; + arr[1] = 200L; + System.out.println(arr[0] + arr[1]); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "300"); +} + +#[test] +fn test_compile_e2e_typed_array_double() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_arr_double", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + double[] arr = new double[2]; + arr[0] = 1.5; + arr[1] = 2.5; + System.out.println(arr[0] + arr[1]); + }"#, + &mut 
class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "4.0"); +} + +// --- For-each tests --- + +#[test] +fn test_compile_e2e_foreach_array() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_foreach", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[] arr = new int[3]; + arr[0] = 10; + arr[1] = 20; + arr[2] = 30; + int sum = 0; + for (int x : arr) { + sum = sum + x; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "60"); +} + +// --- P1 E2E tests --- + +#[test] +fn test_compile_e2e_multi_catch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_multi_catch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + try { + throw new IllegalArgumentException("boom"); + } catch (IllegalArgumentException | NullPointerException e) { + System.out.println("caught multi"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "caught multi", + "expected 'caught multi' but got: {}", + output + ); +} + +#[test] +fn test_compile_e2e_synchronized() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_synchronized", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + Object lock = new Object(); + synchronized
(lock) { + System.out.println("locked"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "locked", "expected 'locked' but got: {}", output); +} + +// --- P2 E2E tests --- + +#[test] +fn test_compile_e2e_var() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = + compile_and_load("e2e_var", "java-assets/src/HelloWorld.java", "HelloWorld"); + + compile_method_body( + r#"{ + var x = 10; + var s = "hello"; + System.out.println(x); + System.out.println(s); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "10\nhello", + "expected '10\\nhello' but got: {}", + output + ); +} + +#[test] +fn test_compile_e2e_var_long() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_var_long", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + var x = 100L; + System.out.println(x); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "100", "expected '100' but got: {}", output); +} + +#[test] +fn test_compile_e2e_multi_dim_array() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_multi_dim", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[][] arr = new int[2][3]; + arr[0][0] = 42; + arr[1][2] = 99; + System.out.println(arr[0][0]); + System.out.println(arr[1][2]); + }"#, + &mut 
class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "42\n99", "expected '42\\n99' but got: {}", output); +} + +#[test] +fn test_compile_e2e_switch_expr() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_switch_expr", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 2; + int r = switch (x) { + case 1 -> 10; + case 2 -> 20; + case 3 -> 30; + default -> 0; + }; + System.out.println(r); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "20", "expected '20' but got: {}", output); +} + +#[test] +fn test_compile_e2e_switch_expr_default() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_switch_expr_default", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 99; + int r = switch (x) { + case 1 -> 10; + default -> 0; + }; + System.out.println(r); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "0", "expected '0' but got: {}", output); +} + +// --- Additional E2E tests --- + +/// Test: Switch fall-through — case 2 executes case 2 and case 3 bodies (no break between). 
+#[test] +fn test_compile_e2e_switch_fallthrough() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_switch_fallthrough", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 2; + int acc = 0; + switch (x) { + case 1: + acc = acc + 1; + case 2: + acc = acc + 10; + case 3: + acc = acc + 100; + break; + default: + acc = -1; + } + System.out.println(acc); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "110", "expected '110' but got: {}", output); +} + +/// Test: For-each over an Iterable (java.util.List) via invokeinterface iterator/hasNext/next. +#[test] +fn test_compile_e2e_foreach_list() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_foreach_list", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + java.util.List<String> list = new java.util.ArrayList<>(); + list.add("alpha"); + list.add("beta"); + list.add("gamma"); + for (String s : list) { + System.out.println(s); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "alpha\nbeta\ngamma", + "expected three lines but got: {}", + output + ); +} + +/// Test: Null concatenation — null references stringify to "null" in String concatenation.
+#[test] +fn test_compile_e2e_null_concatenation() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_null_concat", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + String s = null; + String result = "value=" + s; + System.out.println(result); + System.out.println("literal=" + null); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "value=null\nliteral=null", + "expected null concat output but got: {}", + output + ); +} + +/// Test: Multi-catch second type — verifies both types in a multi-catch are matched. +#[test] +fn test_compile_e2e_multicatch_second_type() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_multicatch_second", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + try { + throw new NullPointerException("npe"); + } catch (IllegalArgumentException | NullPointerException e) { + System.out.println("caught: " + e.getClass().getSimpleName()); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "caught: NullPointerException", + "expected multi-catch output but got: {}", + output + ); +} + +/// Test: Try-catch in loop with continue — catch handler executes continue to the loop. 
+#[test] +fn test_compile_e2e_trycatch_in_loop_continue() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_trycatch_loop", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int count = 0; + for (int i = 0; i < 5; i++) { + try { + if (i == 2) throw new RuntimeException("skip"); + count = count + 1; + } catch (RuntimeException e) { + continue; + } + } + System.out.println(count); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "4", "expected '4' but got: {}", output); +} + +/// Test: Synchronized block with exception — verifies monitorexit on exception path. +#[test] +fn test_compile_e2e_synchronized_exception_path() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_sync_exception", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + Object lock = new Object(); + try { + synchronized (lock) { + throw new RuntimeException("inside sync"); + } + } catch (RuntimeException e) { + System.out.println("caught after sync: " + e.getMessage()); + } + System.out.println("done"); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "caught after sync: inside sync\ndone", + "expected sync exception output but got: {}", + output + ); +} + +/// Test: Int overflow wraps — Java int arithmetic wraps at 32-bit boundaries. 
+#[test] +fn test_compile_e2e_int_overflow_wraps() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_int_overflow", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int max = 2147483647; + int overflow = max + 1; + System.out.println(overflow); + int min = -2147483648; + int underflow = min - 1; + System.out.println(underflow); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "-2147483648\n2147483647", + "expected overflow values but got: {}", + output + ); +} + +/// Test: Narrowing casts — long-to-byte, long-to-short, double-to-float. +#[test] +fn test_compile_e2e_narrowing_casts() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_narrowing", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long big = 511L; + byte b = (byte) big; + System.out.println(b); + long val = 40000L; + short s = (short) val; + System.out.println(s); + double d = 1.23456789; + float f = (float) d; + System.out.println(f); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "-1\n-25536\n1.2345679", + "expected narrowing cast output but got: {}", + output + ); +} + +/// Test: Ternary side effect isolation — only the chosen branch's side effect executes. 
+#[test] +fn test_compile_e2e_ternary_side_effect() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_ternary_side", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[] counter = new int[1]; + counter[0] = 0; + int x = 1; + int result = (x > 0) ? (counter[0] = counter[0] + 10) : (counter[0] = counter[0] - 1); + System.out.println(result); + System.out.println(counter[0]); + x = -1; + result = (x > 0) ? (counter[0] = counter[0] + 10) : (counter[0] = counter[0] - 1); + System.out.println(result); + System.out.println(counter[0]); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "10\n10\n9\n9", + "expected ternary side effect output but got: {}", + output + ); +} + +/// Test: Boolean local from comparison — stores comparison result in a local variable. +#[test] +fn test_compile_e2e_bool_from_comparison() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_bool_cmp", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 7; + boolean isOdd = (x % 2) != 0; + boolean isPositive = x > 0; + if (isOdd && isPositive) { + System.out.println("odd and positive"); + } + boolean b = false; + b = true; + System.out.println(b); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "odd and positive\ntrue", + "expected boolean comparison output but got: {}", + output + ); +} + +/// Test: Var inferred from expression — `var` types resolved across widening chains. 
+#[test] +fn test_compile_e2e_var_inferred_type() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_var_infer", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 10; + long b = 20L; + var sumLong = a + b; + var doubled = sumLong * 2L; + System.out.println(doubled); + double d = 1.5; + var product = sumLong * d; + System.out.println(product); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "60\n45.0", + "expected var inferred type output but got: {}", + output + ); +} + +/// Test: Zero-length array and length field access. +#[test] +fn test_compile_e2e_array_length() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "e2e_array_len", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[] empty = new int[0]; + System.out.println(empty.length); + String[] strs = new String[3]; + strs[0] = "a"; + strs[1] = "b"; + strs[2] = "c"; + if (strs.length == 3) { + System.out.println("correct length"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "0\ncorrect length", + "expected array length output but got: {}", + output + ); +} diff --git a/tests/compiler/main.rs b/tests/compiler/main.rs new file mode 100644 index 0000000..c4125cf --- /dev/null +++ b/tests/compiler/main.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "compile")] + +use std::fs; +use std::io::{Cursor, Read}; +use std::process::Command; + +use binrw::BinWrite; +use binrw::prelude::*; +use classfile_parser::ClassFile; +use 
classfile_parser::code_attribute::Instruction; +use classfile_parser::compile::{ + CompileOptions, compile_method_body, generate_bytecode, parse_method_body, prepend_method_body, +}; + +mod e2e; +mod param_access; +mod parser; +mod prepend; +mod stress; + +// --- Test helpers --- + +fn java_available() -> bool { + Command::new("javac") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + && Command::new("java") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +#[allow(unused)] +fn compile_and_load( + test_name: &str, + java_src: &str, + class_name: &str, +) -> (std::path::PathBuf, std::path::PathBuf, ClassFile) { + let tmp_dir = std::env::temp_dir().join(format!("classfile_compile_{}", test_name)); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let compile = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg(java_src) + .output() + .expect("failed to run javac"); + assert!( + compile.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&compile.stderr) + ); + + let class_path = tmp_dir.join(format!("{}.class", class_name)); + let mut class_bytes = Vec::new(); + std::fs::File::open(&class_path) + .expect("failed to open compiled class") + .read_to_end(&mut class_bytes) + .unwrap(); + let class_file = + ClassFile::read(&mut Cursor::new(&class_bytes)).expect("failed to parse class"); + + (tmp_dir, class_path, class_file) +} + +fn write_and_run( + tmp_dir: &std::path::Path, + class_path: &std::path::Path, + class_file: &ClassFile, + class_name: &str, +) -> String { + let mut out = Cursor::new(Vec::new()); + class_file.write(&mut out).expect("failed to write class"); + fs::write(class_path, out.into_inner()).expect("failed to write class file"); + + let run = Command::new("java") + .arg("-cp") + .arg(tmp_dir) + .arg(class_name) + .output() + .expect("failed to run java"); + assert!( + run.status.success(), + "java failed (exit {}): stderr={}", 
+ run.status, + String::from_utf8_lossy(&run.stderr) + ); + String::from_utf8_lossy(&run.stdout).trim().to_string() +} diff --git a/tests/compiler/param_access.rs b/tests/compiler/param_access.rs new file mode 100644 index 0000000..26f57f9 --- /dev/null +++ b/tests/compiler/param_access.rs @@ -0,0 +1,167 @@ +use super::*; + +/// Like compile_and_load but passes `-g` to javac so LocalVariableTable is present. +#[allow(unused)] +fn compile_and_load_debug( + test_name: &str, + java_src: &str, + class_name: &str, +) -> (std::path::PathBuf, std::path::PathBuf, ClassFile) { + let tmp_dir = std::env::temp_dir().join(format!("classfile_compile_{}", test_name)); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let compile = Command::new("javac") + .arg("-g") + .arg("-d") + .arg(&tmp_dir) + .arg(java_src) + .output() + .expect("failed to run javac"); + assert!( + compile.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&compile.stderr) + ); + + let class_path = tmp_dir.join(format!("{}.class", class_name)); + let mut class_bytes = Vec::new(); + std::fs::File::open(&class_path) + .expect("failed to open compiled class") + .read_to_end(&mut class_bytes) + .unwrap(); + let class_file = + ClassFile::read(&mut Cursor::new(&class_bytes)).expect("failed to parse class"); + + (tmp_dir, class_path, class_file) +} + +#[test] +fn test_param_access_positional_arg0() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "param_positional", + "java-assets/src/ParamAccess.java", + "ParamAccess", + ); + + // main(String[] args): arg0 is args (String[]) + // Use arg0.length to verify array param access works + compile_method_body( + r#"{ System.out.println(arg0.length); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "ParamAccess"); + 
assert_eq!(output, "0"); +} + +#[test] +fn test_param_access_debug_name() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load_debug( + "param_debug", + "java-assets/src/ParamAccess.java", + "ParamAccess", + ); + + // With -g, the original parameter name "args" is available via LocalVariableTable + compile_method_body( + r#"{ System.out.println(args.length); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "ParamAccess"); + assert_eq!(output, "0"); +} + +#[test] +fn test_param_access_wide_types() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "param_wide", + "java-assets/src/ParamAccess.java", + "ParamAccess", + ); + + // wideParams(int a, long b, String c) + // arg0 = int (slot 0, width 1), arg1 = long (slots 1-2, width 2), arg2 = String (slot 3) + compile_method_body( + r#"{ System.out.println(arg0); System.out.println(arg1); System.out.println(arg2); }"#, + &mut class_file, + "wideParams", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // The original Java main does not call wideParams, + // so we replace main with a body that calls it directly via invokestatic + compile_method_body( + r#"{ ParamAccess.wideParams(42, 123456789L, "hello"); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "ParamAccess"); + assert_eq!(output, "42\n123456789\nhello"); +} + +#[test] +fn test_param_access_instance_method() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "param_instance", + "java-assets/src/ParamAccess.java", + "ParamAccess", + ); + + //
instanceMethod(String name): this = slot 0, arg0 = name (slot 1) + compile_method_body( + r#"{ System.out.println(arg0); }"#, + &mut class_file, + "instanceMethod", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Patch main to create instance and call instanceMethod + compile_method_body( + r#"{ + ParamAccess obj = new ParamAccess(); + obj.instanceMethod("world"); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "ParamAccess"); + assert_eq!(output, "world"); +} diff --git a/tests/compiler/parser.rs b/tests/compiler/parser.rs new file mode 100644 index 0000000..ef5e806 --- /dev/null +++ b/tests/compiler/parser.rs @@ -0,0 +1,742 @@ +use super::*; + +// --- Basic parser tests --- + +#[test] +fn test_parse_return_int() { + let stmts = parse_method_body("{ return 42; }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_local_decl() { + let stmts = parse_method_body("{ int x = 10; return x; }").unwrap(); + assert_eq!(stmts.len(), 2); +} + +#[test] +fn test_parse_if_else() { + let stmts = parse_method_body("{ if (x > 0) { return 1; } else { return -1; } }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_while_loop() { + let stmts = parse_method_body("{ while (i < 10) { i = i + 1; } }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_for_loop() { + let stmts = parse_method_body("{ for (int i = 0; i < 10; i++) { sum += i; } }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_method_call() { + let stmts = parse_method_body("{ System.out.println(\"hello\"); }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_string_concat() { + let stmts = parse_method_body("{ String s = \"hello\" + \" world\"; }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_new_object() { + let stmts = parse_method_body("{ StringBuilder sb = new 
StringBuilder(); }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_comparison_ops() { + let stmts = parse_method_body("{ return a == b && c != d || e < f; }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_ternary() { + let stmts = parse_method_body("{ return x > 0 ? x : -x; }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_array_access() { + let stmts = parse_method_body("{ return arr[0]; }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_cast() { + let stmts = parse_method_body("{ long x = (long) y; }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_throw() { + let stmts = parse_method_body("{ throw new RuntimeException(\"error\"); }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_break_continue() { + let stmts = parse_method_body("{ while (true) { if (done) break; continue; } }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +#[test] +fn test_parse_compound_assign() { + let stmts = parse_method_body("{ x += 1; y -= 2; z *= 3; }").unwrap(); + assert_eq!(stmts.len(), 3); +} + +#[test] +fn test_parse_increment_decrement() { + let stmts = parse_method_body("{ i++; --j; }").unwrap(); + assert_eq!(stmts.len(), 2); +} + +// --- Switch and try-catch parser tests --- + +#[test] +fn test_parse_switch() { + let stmts = parse_method_body( + "{ switch (x) { case 1: return 1; case 2: case 3: return 23; default: return 0; } }", + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::Switch { + cases, + default_body, + .. 
+ } => { + assert_eq!(cases.len(), 2); + assert_eq!(cases[0].values, vec![1]); + assert_eq!(cases[1].values, vec![2, 3]); // fall-through grouping + assert!(default_body.is_some()); + } + other => panic!("expected Switch, got {:?}", other), + } +} + +#[test] +fn test_parse_try_catch() { + let stmts = + parse_method_body("{ try { foo(); } catch (Exception e) { bar(); } finally { baz(); } }") + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::TryCatch { + catches, + finally_body, + .. + } => { + assert_eq!(catches.len(), 1); + assert_eq!(catches[0].var_name, "e"); + assert!(finally_body.is_some()); + } + other => panic!("expected TryCatch, got {:?}", other), + } +} + +#[test] +fn test_parse_try_multiple_catches() { + let stmts = parse_method_body( + "{ try { foo(); } catch (RuntimeException e) { a(); } catch (Exception e) { b(); } }", + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::TryCatch { catches, .. } => { + assert_eq!(catches.len(), 2); + } + other => panic!("expected TryCatch, got {:?}", other), + } +} + +// --- For-each and string concat parser tests --- + +#[test] +fn test_parse_foreach() { + let stmts = parse_method_body("{ for (int x : arr) { sum = sum + x; } }").unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::ForEach { + element_type, + var_name, + .. 
+ } => { + assert_eq!(var_name, "x"); + assert_eq!( + *element_type, + classfile_parser::compile::ast::TypeName::Primitive( + classfile_parser::compile::ast::PrimitiveKind::Int, + ), + ); + } + other => panic!("expected ForEach, got: {:?}", other), + } +} + +#[test] +fn test_parse_string_concat_expr() { + // Verify string concat parses as BinaryOp::Add + let stmts = parse_method_body(r#"{ String s = "hello" + " world"; }"#).unwrap(); + assert_eq!(stmts.len(), 1); +} + +// --- P1 parser tests --- + +#[test] +fn test_parse_multi_catch() { + let stmts = parse_method_body( + "{ try { foo(); } catch (IllegalArgumentException | RuntimeException e) { bar(); } }", + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::TryCatch { catches, .. } => { + assert_eq!(catches.len(), 1); + assert_eq!(catches[0].exception_types.len(), 2); + assert_eq!(catches[0].var_name, "e"); + assert_eq!( + catches[0].exception_types[0], + classfile_parser::compile::ast::TypeName::Class("IllegalArgumentException".into()), + ); + assert_eq!( + catches[0].exception_types[1], + classfile_parser::compile::ast::TypeName::Class("RuntimeException".into()), + ); + } + other => panic!("expected TryCatch, got {:?}", other), + } +} + +#[test] +fn test_parse_multi_catch_three_types() { + let stmts = parse_method_body( + "{ try { foo(); } catch (IOException | SQLException | RuntimeException e) { bar(); } }", + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::TryCatch { catches, .. 
} => { + assert_eq!(catches.len(), 1); + assert_eq!(catches[0].exception_types.len(), 3); + } + other => panic!("expected TryCatch, got {:?}", other), + } +} + +#[test] +fn test_parse_synchronized() { + let stmts = parse_method_body("{ synchronized (this) { foo(); } }").unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::Synchronized { lock_expr, body } => { + assert!(matches!( + lock_expr, + classfile_parser::compile::ast::CExpr::This + )); + assert_eq!(body.len(), 1); + } + other => panic!("expected Synchronized, got {:?}", other), + } +} + +#[test] +fn test_parse_synchronized_with_expr() { + let stmts = parse_method_body("{ synchronized (lock) { x = 1; y = 2; } }").unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::Synchronized { lock_expr, body } => { + assert!(matches!( + lock_expr, + classfile_parser::compile::ast::CExpr::Ident(_) + )); + assert_eq!(body.len(), 2); + } + other => panic!("expected Synchronized, got {:?}", other), + } +} + +// --- P2 parser tests --- + +#[test] +fn test_parse_var_decl() { + let stmts = parse_method_body("{ var x = 42; }").unwrap(); + assert_eq!(stmts.len(), 1); + // Should parse as LocalDecl with __var__ sentinel type + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { ty, name, init } => { + assert_eq!(name, "x"); + assert!(init.is_some()); + assert_eq!( + *ty, + classfile_parser::compile::ast::TypeName::Class("__var__".into()) + ); + } + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_var_string() { + let stmts = parse_method_body(r#"{ var s = "hello"; }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { name, init, .. 
} => { + assert_eq!(name, "s"); + assert!(matches!( + init, + Some(classfile_parser::compile::ast::CExpr::StringLiteral(_)) + )); + } + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_multi_dim_array() { + let stmts = parse_method_body("{ int[][] arr = new int[3][4]; }").unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. + } => match expr { + classfile_parser::compile::ast::CExpr::NewMultiArray { + element_type, + dimensions, + } => { + assert_eq!( + *element_type, + classfile_parser::compile::ast::TypeName::Primitive( + classfile_parser::compile::ast::PrimitiveKind::Int + ) + ); + assert_eq!(dimensions.len(), 2); + } + other => panic!("expected NewMultiArray, got {:?}", other), + }, + other => panic!("expected LocalDecl with init, got {:?}", other), + } +} + +#[test] +fn test_parse_generic_method() { + // obj.method() should parse without error + let stmts = parse_method_body(r#"{ obj.method(); }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::ExprStmt( + classfile_parser::compile::ast::CExpr::MethodCall { name, .. }, + ) => { + assert_eq!(name, "method"); + } + other => panic!("expected MethodCall, got {:?}", other), + } +} + +#[test] +fn test_parse_switch_expr() { + let stmts = parse_method_body( + r#"{ + int x = 1; + int r = switch (x) { + case 1 -> 10; + case 2 -> 20; + default -> 0; + }; + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 2); + match &stmts[1] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. + } => match expr { + classfile_parser::compile::ast::CExpr::SwitchExpr { cases, .. 
} => { + assert_eq!(cases.len(), 2); + } + other => panic!("expected SwitchExpr, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_switch_expr_multi_case() { + let stmts = parse_method_body( + r#"{ + int r = switch (x) { + case 1, 2 -> 10; + case 3 -> 30; + default -> 0; + }; + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. + } => match expr { + classfile_parser::compile::ast::CExpr::SwitchExpr { cases, .. } => { + assert_eq!(cases.len(), 2); + assert_eq!(cases[0].values.len(), 2); + } + other => panic!("expected SwitchExpr, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_lambda_no_args() { + let stmts = parse_method_body(r#"{ Runnable r = () -> System.out.println("hi"); }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. + } => match expr { + classfile_parser::compile::ast::CExpr::Lambda { params, .. } => { + assert_eq!(params.len(), 0); + } + other => panic!("expected Lambda, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_lambda_typed_param() { + let stmts = parse_method_body(r#"{ var f = (int x) -> x + 1; }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. 
+ } => match expr { + classfile_parser::compile::ast::CExpr::Lambda { params, body } => { + assert_eq!(params.len(), 1); + assert_eq!(params[0].name, "x"); + assert!(params[0].ty.is_some()); + assert!(matches!( + body, + classfile_parser::compile::ast::LambdaBody::Expr(_) + )); + } + other => panic!("expected Lambda, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_lambda_block() { + let stmts = parse_method_body(r#"{ var f = (int x) -> { return x + 1; }; }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. + } => match expr { + classfile_parser::compile::ast::CExpr::Lambda { params, body } => { + assert_eq!(params.len(), 1); + assert!(matches!( + body, + classfile_parser::compile::ast::LambdaBody::Block(_) + )); + } + other => panic!("expected Lambda, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_method_ref() { + let stmts = parse_method_body(r#"{ var f = String::valueOf; }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. 
+ } => match expr { + classfile_parser::compile::ast::CExpr::MethodRef { + class_name, + method_name, + } => { + assert_eq!(class_name, "String"); + assert_eq!(method_name, "valueOf"); + } + other => panic!("expected MethodRef, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +#[test] +fn test_parse_arrow_token() { + // Verify the arrow token works in switch expressions + let stmts = parse_method_body( + r#"{ + int r = switch (1) { + case 1 -> 42; + default -> 0; + }; + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 1); +} + +// --- Parser stress tests --- + +/// Deeply nested expression parsing +#[test] +fn test_parse_stress_deeply_nested_expr() { + let stmts = parse_method_body("{ int x = ((((((1 + 2) * 3) - 4) / 5) % 6) + 7); }").unwrap(); + assert_eq!(stmts.len(), 1); +} + +/// Multiple statements in a row +#[test] +fn test_parse_stress_many_statements() { + let mut body = String::from("{ "); + for i in 0..50 { + body.push_str(&format!("int v{} = {}; ", i, i)); + } + body.push_str(" }"); + let stmts = parse_method_body(&body).unwrap(); + assert_eq!(stmts.len(), 50); +} + +/// Complex type declarations +#[test] +fn test_parse_stress_type_decls() { + let stmts = parse_method_body( + r#"{ + int a = 1; + long b = 2L; + float c = 3.0f; + double d = 4.0; + boolean e = true; + char f = 'x'; + String g = "hi"; + int[] h = new int[5]; + int[][] i = new int[3][4]; + Object j = null; + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 10); +} + +/// Switch with many cases +#[test] +fn test_parse_stress_switch_many_cases() { + let mut body = String::from("{ switch (x) { "); + for i in 0..20 { + body.push_str(&format!("case {}: return {}; ", i, i * 10)); + } + body.push_str("default: return -1; } }"); + let stmts = parse_method_body(&body).unwrap(); + assert_eq!(stmts.len(), 1); +} + +/// Nested switch expressions +#[test] +fn test_parse_stress_nested_switch_expr() { + let stmts = parse_method_body( + r#"{ + int outer = switch (a) { 
+ case 1 -> switch (b) { + case 10 -> 100; + default -> 0; + }; + default -> -1; + }; + }"#, + ); + // This may or may not parse depending on implementation — record result + match stmts { + Ok(s) => assert_eq!(s.len(), 1), + Err(e) => eprintln!("nested switch expr not supported: {}", e), + } +} + +/// For-each with dotted type +#[test] +fn test_parse_stress_foreach_dotted_type() { + let stmts = parse_method_body("{ for (java.lang.String s : list) { System.out.println(s); } }") + .unwrap(); + assert_eq!(stmts.len(), 1); +} + +/// All binary operators in one expression +#[test] +fn test_parse_stress_all_binops() { + let stmts = parse_method_body( + "{ int x = a + b - c * d / e % f; int y = g & h | i ^ j; int z = k << l >> m >>> n; }", + ) + .unwrap(); + assert_eq!(stmts.len(), 3); +} + +/// Chained method calls +#[test] +fn test_parse_stress_chained_calls() { + let stmts = + parse_method_body(r#"{ String s = obj.method1().method2().method3().toString(); }"#) + .unwrap(); + assert_eq!(stmts.len(), 1); +} + +/// Lambda with multiple parameters +#[test] +fn test_parse_stress_lambda_multi_param() { + let stmts = parse_method_body(r#"{ var f = (int a, int b, int c) -> a + b + c; }"#).unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. + } => match expr { + classfile_parser::compile::ast::CExpr::Lambda { params, .. } => { + assert_eq!(params.len(), 3); + } + other => panic!("expected Lambda, got {:?}", other), + }, + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +/// Lambda with block body containing control flow +#[test] +fn test_parse_stress_lambda_complex_body() { + let stmts = parse_method_body( + r#"{ var f = (int x) -> { + if (x > 0) { + return x * 2; + } else { + return -x; + } + }; }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::LocalDecl { + init: Some(expr), .. 
+ } => { + match expr { + classfile_parser::compile::ast::CExpr::Lambda { body, .. } => { + match body { + classfile_parser::compile::ast::LambdaBody::Block(stmts) => { + assert_eq!(stmts.len(), 1); // one if statement + } + _ => panic!("expected block body"), + } + } + other => panic!("expected Lambda, got {:?}", other), + } + } + other => panic!("expected LocalDecl, got {:?}", other), + } +} + +/// Multiple var declarations in sequence +#[test] +fn test_parse_stress_var_sequence() { + let stmts = parse_method_body( + r#"{ + var a = 1; + var b = 2L; + var c = 3.0f; + var d = 4.0; + var e = "str"; + var f = true; + var g = 'x'; + var h = null; + var i = new Object(); + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 9); +} + +/// Comprehensive expression in a single assignment +#[test] +fn test_parse_stress_complex_expr() { + let stmts = + parse_method_body("{ int x = (a > 0 && b < 10) || !(c == d) ? (e + f) * g : h - i / j; }") + .unwrap(); + assert_eq!(stmts.len(), 1); +} + +/// Try-catch with multiple catch blocks and finally +#[test] +fn test_parse_stress_complex_try_catch() { + let stmts = parse_method_body( + r#"{ + try { + foo(); + } catch (IllegalArgumentException e) { + bar(); + } catch (NullPointerException | ArrayIndexOutOfBoundsException e) { + baz(); + } catch (RuntimeException e) { + qux(); + } finally { + cleanup(); + } + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 1); + match &stmts[0] { + classfile_parser::compile::ast::CStmt::TryCatch { + catches, + finally_body, + .. 
+ } => { + assert_eq!(catches.len(), 3); + assert_eq!(catches[1].exception_types.len(), 2); // multi-catch + assert!(finally_body.is_some()); + } + other => panic!("expected TryCatch, got {:?}", other), + } +} + +/// Synchronized with complex expression +#[test] +fn test_parse_stress_synchronized_complex() { + let stmts = parse_method_body( + r#"{ + synchronized (this) { + int x = 1; + for (int i = 0; i < 10; i++) { + x = x + i; + } + if (x > 50) { + throw new RuntimeException("too big"); + } + } + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 1); +} + +/// Generic type parameters in method calls +#[test] +fn test_parse_stress_generic_params() { + let stmts = parse_method_body( + r#"{ + obj.method1(); + obj.method2(); + }"#, + ) + .unwrap(); + assert_eq!(stmts.len(), 2); +} diff --git a/tests/compiler/prepend.rs b/tests/compiler/prepend.rs new file mode 100644 index 0000000..e6600d9 --- /dev/null +++ b/tests/compiler/prepend.rs @@ -0,0 +1,358 @@ +use super::*; + +// --- Prepend mode tests --- + +#[test] +fn test_prepend_println() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "prepend_println", + "java-assets/src/PrependTest.java", + "PrependTest", + ); + + // Replace main to just print "original", then prepend "before" + compile_method_body( + r#"{ System.out.println("original"); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + prepend_method_body( + r#"{ System.out.println("before"); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "PrependTest"); + assert_eq!(output, "before\noriginal"); +} + +#[test] +fn test_prepend_with_param_access() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "prepend_param", + 
"java-assets/src/PrependTest.java", + "PrependTest", + ); + + // Prepend code that prints the parameter before original body runs + prepend_method_body( + r#"{ System.out.println(arg0); }"#, + &mut class_file, + "withParams", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Patch main to call withParams + compile_method_body( + r#"{ PrependTest.withParams("world"); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "PrependTest"); + assert_eq!(output, "world\nhello world"); +} + +#[test] +fn test_prepend_with_local_variable() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "prepend_local", + "java-assets/src/PrependTest.java", + "PrependTest", + ); + + // Prepend code that declares a local variable + prepend_method_body( + r#"{ int y = 99; System.out.println("y=" + y); }"#, + &mut class_file, + "withLocal", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Patch main to call withLocal + compile_method_body( + r#"{ PrependTest.withLocal(); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "PrependTest"); + assert_eq!(output, "y=99\nx=10"); +} + +#[test] +fn test_prepend_to_method_with_try_catch() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "prepend_trycatch", + "java-assets/src/PrependTest.java", + "PrependTest", + ); + + prepend_method_body( + r#"{ System.out.println("before try"); }"#, + &mut class_file, + "withTryCatch", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Patch main to call withTryCatch + compile_method_body( + r#"{ PrependTest.withTryCatch(); }"#, + &mut class_file, + "main", + None, + 
&CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "PrependTest"); + assert_eq!(output, "before try\ntry"); +} + +#[test] +fn test_prepend_with_branches() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "prepend_branches", + "java-assets/src/PrependTest.java", + "PrependTest", + ); + + // Prepend an if/else that has branch targets (requires StackMapTable merge) + prepend_method_body( + r#"{ if (arg0 > 0) { System.out.println("positive"); } else { System.out.println("non-positive"); } }"#, + &mut class_file, + "withBranch", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Patch main to call withBranch + compile_method_body( + r#"{ PrependTest.withBranch(5); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "PrependTest"); + assert_eq!(output, "positive\nn=5"); +} + +#[test] +fn test_prepend_macro() { + if !java_available() { + eprintln!("SKIP: java/javac not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "prepend_macro", + "java-assets/src/PrependTest.java", + "PrependTest", + ); + + // Replace main to just print "original", then prepend with macro + compile_method_body( + r#"{ System.out.println("original"); }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + classfile_parser::prepend_method!( + class_file, + "main", + r#"{ System.out.println("macro prepend"); }"# + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "PrependTest"); + assert_eq!(output, "macro prepend\noriginal"); +} + +// --- StackMapTable edge case tests --- + +/// Regression test: wide local (long) followed by non-wide local + branch. 
+/// Verifies StackMapTable encoding doesn't include explicit Top continuation slots. +#[test] +fn test_wide_local_then_narrow_with_branch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "wide_narrow_branch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long x = 100L; + int y = 42; + if (y > 10) { + System.out.println(x + y); + } else { + System.out.println(y); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "142"); +} + +/// StackMapTable edge case: many statements before a branch pushes the frame delta past 63, +/// which should trigger SameFrameExtended instead of SameFrame. +#[test] +fn test_stackmap_large_delta_extended_frame() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "smt_extended", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + // Generate enough bytecode before the branch to push the delta past 63. + // Each println is ~8 bytes (getstatic 3 + ldc 2 + invokevirtual 3 = 8). + // We need > 63 bytes before the if statement to force SameFrameExtended. 
+ compile_method_body( + r#"{ + System.out.println("a"); + System.out.println("b"); + System.out.println("c"); + System.out.println("d"); + System.out.println("e"); + System.out.println("f"); + System.out.println("g"); + System.out.println("h"); + System.out.println("i"); + int x = 1; + if (x > 0) { + System.out.println("yes"); + } else { + System.out.println("no"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert!( + output.ends_with("yes"), + "expected 'yes' at end, got: {}", + output + ); +} + +/// StackMapTable edge case: prepending code that pushes existing frames past the +/// SameFrame threshold, verifying re-encoding from SameFrame to SameFrameExtended. +#[test] +fn test_prepend_stackmap_reencoding_threshold() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = + compile_and_load("smt_reenc", "java-assets/src/HelloWorld.java", "HelloWorld"); + + // First replace with code that has a branch near the SameFrame limit + compile_method_body( + r#"{ + int x = 1; + if (x > 0) { + System.out.println("original-yes"); + } else { + System.out.println("original-no"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Now prepend enough code to push the existing frames past offset 63 + let mut opts = CompileOptions::default(); + opts.insert_mode = classfile_parser::compile::InsertMode::Prepend; + compile_method_body( + r#"{ + System.out.println("p1"); + System.out.println("p2"); + System.out.println("p3"); + System.out.println("p4"); + System.out.println("p5"); + System.out.println("p6"); + System.out.println("p7"); + System.out.println("p8"); + }"#, + &mut class_file, + "main", + None, + &opts, + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert!( + 
output.ends_with("original-yes"), + "expected 'original-yes' at end, got: {}", + output + ); +} diff --git a/tests/compiler/stress.rs b/tests/compiler/stress.rs new file mode 100644 index 0000000..4a5d0dd --- /dev/null +++ b/tests/compiler/stress.rs @@ -0,0 +1,1705 @@ +use super::*; + +// --------------------------------------------------------------------------- +// Category 1: Complex control flow +// --------------------------------------------------------------------------- + +/// Nested loops with break/continue interacting across levels +#[test] +fn test_stress_nested_loop_break_continue() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_nested_break", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int result = 0; + for (int i = 0; i < 10; i++) { + if (i == 3) continue; + if (i == 7) break; + int j = 0; + while (j < 5) { + if (j == 2) { + j++; + continue; + } + result = result + 1; + j++; + } + // i runs 0,1,2,4,5,6 (skip 3, break at 7) = 6 iterations + // j runs 0,1,3,4 (skip 2) = 4 per outer = 24 total + } + System.out.println(result); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "24"); +} + +/// Deeply nested if-else chain (fizzbuzz) +#[test] +fn test_stress_fizzbuzz() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_fizzbuzz", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + for (int i = 1; i <= 20; i++) { + if (i % 15 == 0) { + System.out.println("FizzBuzz"); + } else if (i % 3 == 0) { + System.out.println("Fizz"); + } else if (i % 5 == 0) { + System.out.println("Buzz"); + } else { + 
System.out.println(i); + } + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + let expected = "1\n2\nFizz\n4\nBuzz\nFizz\n7\n8\nFizz\nBuzz\n11\nFizz\n13\n14\nFizzBuzz\n16\n17\nFizz\n19\nBuzz"; + assert_eq!(output, expected); +} + +/// Switch with many sparse cases (triggers lookupswitch) +#[test] +fn test_stress_switch_lookup() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_switch_lookup", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 500; + switch (x) { + case 1: + System.out.println("one"); + break; + case 100: + System.out.println("hundred"); + break; + case 500: + System.out.println("five-hundred"); + break; + case 9999: + System.out.println("nine-thousand"); + break; + default: + System.out.println("other"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "five-hundred"); +} + +/// Switch with dense cases (triggers tableswitch) +#[test] +fn test_stress_switch_table() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_switch_table", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int sum = 0; + for (int i = 0; i < 8; i++) { + switch (i) { + case 0: sum = sum + 1; break; + case 1: sum = sum + 2; break; + case 2: sum = sum + 4; break; + case 3: sum = sum + 8; break; + case 4: sum = sum + 16; break; + case 5: sum = sum + 32; break; + case 6: sum = sum + 64; break; + case 7: sum = sum + 128; break; + default: break; + } + } + System.out.println(sum); + }"#, + &mut class_file, 
+ "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // 1+2+4+8+16+32+64+128 = 255 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "255"); +} + +// --------------------------------------------------------------------------- +// Category 2: Complex arithmetic and type mixing +// --------------------------------------------------------------------------- + +/// Mixed-type arithmetic with widening conversions +#[test] +fn test_stress_mixed_type_arithmetic() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_mixed_arith", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int i = 10; + long l = 20L; + float f = 1.5f; + double d = 2.5; + + // int + long = long (i2l widening) + long sum1 = i + l; + System.out.println(sum1); + + // int + float = float (i2f widening) + float sum2 = i + f; + System.out.println(sum2); + + // long + double = double (l2d widening) + double sum3 = l + d; + System.out.println(sum3); + + // int * double = double (i2d widening) + double prod = i * d; + System.out.println(prod); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "30\n11.5\n22.5\n25.0"); +} + +/// Compound assignment with different types +#[test] +fn test_stress_compound_assign_types() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_compound_types", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 100; + a += 50; + a -= 25; + a *= 2; + a /= 5; + a %= 7; + System.out.println(a); + + long b = 1000000000L; + b += 2000000000L; + System.out.println(b); + + double c = 
10.0; + c *= 3.14; + System.out.println(c); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // a: (100+50-25)*2/5 = 250/5 = 50; 50%7 = 1 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "1\n3000000000\n31.400000000000002"); +} + +/// Bitwise operations +#[test] +fn test_stress_bitwise_ops() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_bitwise", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 0xFF; + int b = 0x0F; + System.out.println(a & b); + System.out.println(a | b); + System.out.println(a ^ b); + System.out.println(~b); + + // Shift operations + int c = 1; + System.out.println(c << 10); + System.out.println(1024 >> 3); + System.out.println(-1 >>> 28); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // 0xFF & 0x0F = 15, 0xFF | 0x0F = 255, 0xFF ^ 0x0F = 240, ~0x0F = -16 + // 1 << 10 = 1024, 1024 >> 3 = 128, -1 >>> 28 = 15 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "15\n255\n240\n-16\n1024\n128\n15"); +} + +/// Cast chain: int -> long -> double -> int +#[test] +fn test_stress_cast_chain() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_cast_chain", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 42; + long b = (long) a; + double c = (double) b; + float f = (float) c; + int d = (int) f; + System.out.println(d); + + // Truncation: double -> int + double pi = 3.14159; + int truncated = (int) pi; + System.out.println(truncated); + + // Large long -> int truncation + long big = 3000000000L; + int small = (int) big; + 
System.out.println(small); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + // 42 survives round-trip, pi truncates to 3, 3000000000 wraps to -1294967296 + assert_eq!(output, "42\n3\n-1294967296"); +} + +/// Pre/post increment/decrement combinations +#[test] +fn test_stress_increment_decrement() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_incdec", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 10; + a++; + ++a; + System.out.println(a); + + a--; + --a; + System.out.println(a); + + // For loop with increment + int sum = 0; + for (int i = 0; i < 5; i++) { + sum += i; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "12\n10\n10"); +} + +// --------------------------------------------------------------------------- +// Category 3: String operations +// --------------------------------------------------------------------------- + +/// String concat with all primitive types +#[test] +fn test_stress_string_concat_all_types() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_str_all", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int i = 42; + long l = 123456789L; + float f = 3.14f; + double d = 2.718; + boolean b = true; + char c = 'X'; + String s = "i=" + i + " l=" + l + " f=" + f + " d=" + d + " b=" + b + " c=" + c; + System.out.println(s); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = 
write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "i=42 l=123456789 f=3.14 d=2.718 b=true c=X"); +} + +/// String concat in loop (builds string progressively) +#[test] +fn test_stress_string_concat_loop() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_str_loop", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + String s = ""; + for (int i = 0; i < 5; i++) { + s = s + i; + } + System.out.println(s); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "01234"); +} + +// --------------------------------------------------------------------------- +// Category 4: Arrays +// --------------------------------------------------------------------------- + +/// Array operations: create, fill, read, foreach +#[test] +fn test_stress_array_comprehensive() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_arr_comp", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[] arr = new int[10]; + for (int i = 0; i < 10; i++) { + arr[i] = i * i; + } + int sum = 0; + for (int x : arr) { + sum = sum + x; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // 0+1+4+9+16+25+36+49+64+81 = 285 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "285"); +} + +/// Multi-dimensional array: fill a 3x3 matrix, compute its trace, print every element +#[test] +fn test_stress_multi_dim_array_compute() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut 
class_file) = compile_and_load( + "stress_mdarray", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[][] m = new int[3][3]; + int val = 1; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + m[i][j] = val; + val++; + } + } + // Compute trace (diagonal sum) + int trace = m[0][0] + m[1][1] + m[2][2]; + System.out.println(trace); + + // Print all values + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + System.out.println(m[i][j]); + } + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Matrix: [[1,2,3],[4,5,6],[7,8,9]], trace = 1+5+9 = 15 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "15\n1\n2\n3\n4\n5\n6\n7\n8\n9"); +} + +/// Typed arrays: long[], double[], boolean[], char[] +#[test] +fn test_stress_typed_arrays() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_typed_arr", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + long[] longs = new long[2]; + longs[0] = 9999999999L; + longs[1] = 1L; + System.out.println(longs[0] + longs[1]); + + double[] doubles = new double[2]; + doubles[0] = 1.1; + doubles[1] = 2.2; + System.out.println(doubles[0] + doubles[1]); + + boolean[] bools = new boolean[2]; + bools[0] = true; + bools[1] = false; + System.out.println(bools[0]); + System.out.println(bools[1]); + + char[] chars = new char[3]; + chars[0] = 'H'; + chars[1] = 'i'; + chars[2] = '!'; + System.out.println("" + chars[0] + chars[1] + chars[2]); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "10000000000\n3.3000000000000003\ntrue\nfalse\nHi!"); +} + +// 
--------------------------------------------------------------------------- +// Category 5: Exception handling +// --------------------------------------------------------------------------- + +/// Try-catch-finally with exception in different places +#[test] +fn test_stress_try_catch_complex() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_trycatch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + // Test 1: exception caught, finally runs + try { + System.out.println("try1"); + throw new RuntimeException("err"); + } catch (RuntimeException e) { + System.out.println("catch1"); + } finally { + System.out.println("finally1"); + } + + // Test 2: no exception, finally still runs + try { + System.out.println("try2"); + } catch (RuntimeException e) { + System.out.println("catch2-WRONG"); + } finally { + System.out.println("finally2"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "try1\ncatch1\nfinally1\ntry2\nfinally2"); +} + +/// Multi-catch with multiple exception types +#[test] +fn test_stress_multi_catch_variants() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_multi_catch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + // First, throw IllegalArgumentException + try { + throw new IllegalArgumentException("bad arg"); + } catch (IllegalArgumentException | NullPointerException e) { + System.out.println("caught-multi-1"); + } + + // Then throw NullPointerException + try { + throw new NullPointerException("null"); + } catch (IllegalArgumentException | NullPointerException e) { + 
System.out.println("caught-multi-2"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "caught-multi-1\ncaught-multi-2"); +} + +/// Nested try-catch +#[test] +fn test_stress_nested_try_catch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_nested_try", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + try { + System.out.println("outer-try"); + try { + System.out.println("inner-try"); + throw new RuntimeException("inner"); + } catch (RuntimeException e) { + System.out.println("inner-catch"); + } + System.out.println("after-inner"); + } catch (Exception e) { + System.out.println("outer-catch-WRONG"); + } finally { + System.out.println("outer-finally"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, + "outer-try\ninner-try\ninner-catch\nafter-inner\nouter-finally" + ); +} + +// --------------------------------------------------------------------------- +// Category 6: Ternary and logical operators +// --------------------------------------------------------------------------- + +/// Nested ternary expressions +#[test] +fn test_stress_nested_ternary() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_ternary", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + for (int i = 0; i < 5; i++) { + String s = i < 2 ? "low" : (i < 4 ? 
"mid" : "high"); + System.out.println(s); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "low\nlow\nmid\nmid\nhigh"); +} + +/// Complex boolean short-circuit evaluation +#[test] +fn test_stress_short_circuit() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_shortcircuit", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 5; + int b = 10; + int c = 15; + + // Complex short-circuit: (a > 3 && b < 20) || (c == 0) + if ((a > 3 && b < 20) || c == 0) { + System.out.println("yes1"); + } else { + System.out.println("no1"); + } + + // Should short-circuit: false && ... should not evaluate right + if (a > 100 && b > 0) { + System.out.println("yes2"); + } else { + System.out.println("no2"); + } + + // Should short-circuit: true || ... 
should not evaluate right + if (a == 5 || b > 1000) { + System.out.println("yes3"); + } else { + System.out.println("no3"); + } + + // Negation + if (!(a > 10)) { + System.out.println("yes4"); + } else { + System.out.println("no4"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "yes1\nno2\nyes3\nyes4"); +} + +// --------------------------------------------------------------------------- +// Category 7: Object creation and method calls +// --------------------------------------------------------------------------- + +/// StringBuilder method chaining +#[test] +fn test_stress_stringbuilder_chain() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_sb_chain", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + StringBuilder sb = new StringBuilder(); + sb.append("Hello"); + sb.append(" "); + sb.append("World"); + sb.append("!"); + String result = sb.toString(); + System.out.println(result); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "Hello World!"); +} + +/// Creating and using ArrayList: requires new constant pool entries for +/// classes not already referenced in HelloWorld.class, so it exercises +/// adding new class/method refs to the pool. 
+#[test] +fn test_stress_arraylist() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_arraylist", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + java.util.ArrayList list = new java.util.ArrayList(); + list.add("alpha"); + list.add("beta"); + list.add("gamma"); + int sz = list.size(); + System.out.println(sz); + Object item = list.get(1); + System.out.println(item); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .expect("ArrayList compilation should succeed with descriptor inference"); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "3\nbeta", "ArrayList add/size/get should work"); +} + +/// StringBuilder.length() — tests well-known method descriptor heuristic +#[test] +fn test_stress_sb_length() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_sb_length", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + StringBuilder sb = new StringBuilder(); + sb.append("hello"); + sb.append(" world"); + System.out.println(sb.length()); + System.out.println(sb.toString()); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .expect("StringBuilder length compilation should succeed"); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "11\nhello world", "sb.length() should return 11"); +} + +/// instanceof + cast +#[test] +fn test_stress_instanceof_cast() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_instanceof", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + 
r#"{ + Object obj = "hello"; + if (obj instanceof String) { + String s = (String) obj; + System.out.println("is string: " + s); + } else { + System.out.println("not string"); + } + + Object num = new Integer(42); + if (num instanceof String) { + System.out.println("WRONG"); + } else { + System.out.println("not a string"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "is string: hello\nnot a string"); +} + +// --------------------------------------------------------------------------- +// Category 8: var keyword stress +// --------------------------------------------------------------------------- + +/// var with complex type inference +#[test] +fn test_stress_var_inference() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_var_infer", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + var i = 42; + var l = 100L; + var f = 1.5f; + var d = 2.718; + var s = "hello"; + var b = true; + var c = 'Z'; + + System.out.println(i); + System.out.println(l); + System.out.println(f); + System.out.println(d); + System.out.println(s); + System.out.println(b); + System.out.println(c); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "42\n100\n1.5\n2.718\nhello\ntrue\nZ"); +} + +/// var with new object +#[test] +fn test_stress_var_new_object() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_var_newobj", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + var sb = new StringBuilder(); + 
sb.append("var"); + sb.append("-works"); + System.out.println(sb.toString()); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "var-works"); +} + +// --------------------------------------------------------------------------- +// Category 9: Switch expressions stress +// --------------------------------------------------------------------------- + +/// Switch expression assigned to a local, which is then passed to println +#[test] +fn test_stress_switch_expr_as_arg() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_switch_arg", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + for (int i = 1; i <= 4; i++) { + int val = switch (i) { + case 1 -> 100; + case 2 -> 200; + case 3 -> 300; + default -> -1; + }; + System.out.println(val); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "100\n200\n300\n-1"); +} + +/// Switch expression with multi-value cases +#[test] +fn test_stress_switch_expr_multi_value() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_switch_multi", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + for (int day = 1; day <= 7; day++) { + int type = switch (day) { + case 1, 7 -> 0; + case 2, 3, 4, 5, 6 -> 1; + default -> -1; + }; + System.out.println(type); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // day 1 -> 0 (weekend), days 2-6 -> 1 (weekday), day 7 -> 0 (weekend) + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + 
assert_eq!(output, "0\n1\n1\n1\n1\n1\n0"); +} + +// --------------------------------------------------------------------------- +// Category 10: Complex combined scenarios +// --------------------------------------------------------------------------- + +/// Bubble sort implementation +#[test] +fn test_stress_bubble_sort() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_bubblesort", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[] arr = new int[5]; + arr[0] = 5; + arr[1] = 3; + arr[2] = 8; + arr[3] = 1; + arr[4] = 9; + + // Bubble sort + for (int i = 0; i < 5; i++) { + for (int j = 0; j < 4 - i; j++) { + if (arr[j] > arr[j + 1]) { + int temp = arr[j]; + arr[j] = arr[j + 1]; + arr[j + 1] = temp; + } + } + } + + // Print sorted + for (int x : arr) { + System.out.println(x); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "1\n3\n5\n8\n9"); +} + +/// Fibonacci with array memoization +#[test] +fn test_stress_fibonacci() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_fibonacci", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int n = 20; + long[] fib = new long[21]; + fib[0] = 0L; + fib[1] = 1L; + for (int i = 2; i <= n; i++) { + fib[i] = fib[i - 1] + fib[i - 2]; + } + System.out.println(fib[10]); + System.out.println(fib[20]); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "55\n6765"); +} + +/// GCD computation using while loop +#[test] +fn test_stress_gcd() { + 
if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_gcd", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int a = 48; + int b = 18; + while (b != 0) { + int temp = b; + b = a % b; + a = temp; + } + System.out.println(a); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "6"); +} + +/// Power of 2 check using bitwise +#[test] +fn test_stress_power_of_two() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_pow2", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + for (int n = 1; n <= 16; n++) { + boolean isPow2 = (n & (n - 1)) == 0; + if (isPow2) { + System.out.println(n); + } + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "1\n2\n4\n8\n16"); +} + +/// Complex nested feature combination: synchronized + try-catch + for-each + var +#[test] +fn test_stress_feature_combo() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_combo", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + Object lock = new Object(); + var result = 0; + synchronized (lock) { + int[] values = new int[5]; + for (int i = 0; i < 5; i++) { + values[i] = (i + 1) * 10; + } + try { + for (int v : values) { + result = result + v; + } + } catch (Exception e) { + System.out.println("error"); + } + } + System.out.println(result); + }"#, + &mut class_file, + "main", + None, 
+ &CompileOptions::default(), + ) + .unwrap(); + + // 10+20+30+40+50 = 150 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "150"); +} + +/// Null handling +#[test] +fn test_stress_null_handling() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_null", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + String s = null; + if (s == null) { + System.out.println("is null"); + } + s = "not null"; + if (s != null) { + System.out.println(s); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "is null\nnot null"); +} + +// --------------------------------------------------------------------------- +// Category 11: Edge cases and boundary values +// --------------------------------------------------------------------------- + +/// Large int constants +#[test] +fn test_stress_large_constants() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_large_const", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int max = 2147483647; + int min = -2147483648; + System.out.println(max); + System.out.println(min); + + long lmax = 9223372036854775807L; + System.out.println(lmax); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "2147483647\n-2147483648\n9223372036854775807"); +} + +/// Empty loops and blocks +#[test] +fn test_stress_empty_constructs() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let 
(tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_empty", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + // Empty for loop + int i = 0; + for (; i < 10; i++) { + } + System.out.println(i); + + // Empty while loop + while (i > 10) { + } + + // Empty block + { + } + + // Nested empty blocks + { + { + { + } + } + } + System.out.println("done"); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "10\ndone"); +} + +/// Char arithmetic +#[test] +fn test_stress_char_arithmetic() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_char_arith", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + char c = 'A'; + // Char is an int in the JVM + int code = (int) c; + System.out.println(code); + + // Character iteration + for (char ch = 'a'; ch <= 'e'; ch++) { + System.out.println(ch); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "65\na\nb\nc\nd\ne"); +} + +// --------------------------------------------------------------------------- +// Category 12: StackMapTable verification (full verification enabled) +// --------------------------------------------------------------------------- + +/// Complex control flow with full verification +#[test] +fn test_stress_verified_complex_flow() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_verified_flow", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int sum = 0; + for (int i = 0; i 
< 10; i++) { + if (i % 2 == 0) { + sum = sum + i; + } else { + sum = sum - 1; + } + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + // Even i add 0+2+4+6+8 = 20; the five odd i subtract 1 each: 20 - 5 = 15 + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "15"); +} + +/// Switch with full verification +#[test] +fn test_stress_verified_switch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_verified_switch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int x = 3; + switch (x) { + case 1: + System.out.println("one"); + break; + case 2: + System.out.println("two"); + break; + case 3: + System.out.println("three"); + break; + default: + System.out.println("other"); + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "three"); +} + +/// For-each with full verification. Targets a StackMapTable bug in which +/// parameter types in locals are emitted as Null instead of their actual types: +/// the for-each loop generates extra locals (array copy, length, index) that the +/// stack map tracker doesn't fully account for alongside the pre-existing parameter types. 
+#[test] +fn test_stress_verified_foreach() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_verified_foreach", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int[] arr = new int[4]; + arr[0] = 10; + arr[1] = 20; + arr[2] = 30; + arr[3] = 40; + int sum = 0; + for (int x : arr) { + sum = sum + x; + } + System.out.println(sum); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .expect("compilation should succeed"); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "100", "for-each should pass JVM verification"); +} + +// --------------------------------------------------------------------------- +// Category 13: Synchronized stress +// --------------------------------------------------------------------------- + +/// Multiple synchronized blocks +#[test] +fn test_stress_multiple_synchronized() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_multi_sync", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + Object lock1 = new Object(); + Object lock2 = new Object(); + int value = 0; + + synchronized (lock1) { + value = value + 10; + } + synchronized (lock2) { + value = value + 20; + } + synchronized (lock1) { + value = value * 2; + } + System.out.println(value); + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "60"); +} + +// --------------------------------------------------------------------------- +// Category 15: Sieve of Eratosthenes (ultimate algorithm test) +// 
--------------------------------------------------------------------------- + +#[test] +fn test_stress_sieve_of_eratosthenes() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "stress_sieve", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + compile_method_body( + r#"{ + int limit = 30; + boolean[] sieve = new boolean[31]; + // false = prime, true = composite (default false) + sieve[0] = true; + sieve[1] = true; + + for (int i = 2; i * i <= limit; i++) { + if (!sieve[i]) { + for (int j = i * i; j <= limit; j += i) { + sieve[j] = true; + } + } + } + + // Print primes, one per line + for (int i = 2; i <= limit; i++) { + if (!sieve[i]) { + System.out.println(i); + } + } + }"#, + &mut class_file, + "main", + None, + &CompileOptions::default(), + ) + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!(output, "2\n3\n5\n7\n11\n13\n17\n19\n23\n29"); +} diff --git a/tests/decompile.rs b/tests/decompile.rs new file mode 100644 index 0000000..75160f7 --- /dev/null +++ b/tests/decompile.rs @@ -0,0 +1,321 @@ +#![cfg(feature = "decompile")] + +use std::fs; +use std::io::Cursor; + +use binrw::prelude::*; +use classfile_parser::ClassFile; +use classfile_parser::decompile::cfg; +use classfile_parser::decompile::descriptor; +use classfile_parser::decompile::stack_sim; +use classfile_parser::decompile::structuring; +use classfile_parser::decompile::{self, DecompileOptions, Decompiler, RenderConfig}; + +fn load_class(name: &str) -> ClassFile { + let path = format!("java-assets/compiled-classes/{}", name); + let bytes = fs::read(&path).unwrap_or_else(|_| panic!("Failed to read {}", path)); + ClassFile::read(&mut Cursor::new(bytes)).unwrap_or_else(|_| panic!("Failed to parse {}", path)) +} + +// ---- Descriptor tests ---- + +#[test] +fn test_descriptor_parsing() { + assert_eq!( + 
descriptor::parse_method_descriptor("(Ljava/lang/String;I)V") + .unwrap() + .0 + .len(), + 2 + ); + assert_eq!( + descriptor::parse_type_descriptor("[Ljava/lang/Object;").unwrap(), + descriptor::JvmType::Array(Box::new(descriptor::JvmType::Reference( + "java/lang/Object".into() + ))) + ); +} + +// ---- Phase 1: CFG tests ---- + +#[test] +fn test_cfg_basic_class() { + let class = load_class("BasicClass.class"); + for method in &class.methods { + if let Some(code) = method.code() { + let cfg = cfg::build_cfg(code); + assert!(!cfg.blocks.is_empty(), "CFG should have at least one block"); + assert!(cfg.blocks.contains_key(&0), "CFG should start at address 0"); + } + } +} + +#[test] +fn test_cfg_factorial() { + let class = load_class("Factorial.class"); + let method = class + .find_method("factorial") + .expect("factorial method should exist"); + let code = method.code().expect("factorial should have code"); + let cfg = cfg::build_cfg(code); + + // Factorial has branches, so it should have multiple blocks + assert!( + cfg.blocks.len() > 1, + "Factorial CFG should have multiple blocks, got {}", + cfg.blocks.len() + ); + + // Should have at least one conditional branch + let has_conditional = cfg.blocks.values().any(|b| { + matches!( + b.terminator, + classfile_parser::decompile::cfg_types::Terminator::ConditionalBranch { .. 
} + ) + }); + assert!( + has_conditional, + "Factorial CFG should have conditional branches" + ); +} + +#[test] +fn test_cfg_instructions_switches() { + let class = load_class("Instructions.class"); + for method in &class.methods { + if let Some(code) = method.code() { + let cfg = cfg::build_cfg(code); + assert!(!cfg.blocks.is_empty()); + } + } +} + +#[test] +fn test_cfg_dot_output() { + let class = load_class("BasicClass.class"); + if let Some(code) = class.methods.first().and_then(|m| m.code()) { + let cfg = cfg::build_cfg(code); + let dot = cfg.to_dot(); + assert!(dot.contains("digraph CFG")); + assert!(dot.contains("->") || cfg.blocks.len() <= 1); + } +} + +#[test] +fn test_cfg_reverse_postorder() { + let class = load_class("Factorial.class"); + let method = class.find_method("factorial").expect("factorial method"); + let code = method.code().expect("code"); + let cfg = cfg::build_cfg(code); + let rpo = cfg.reverse_postorder(); + assert!(!rpo.is_empty()); + assert_eq!(rpo[0], 0, "RPO should start with entry block"); +} + +// ---- Phase 2: Stack simulation tests ---- + +#[test] +fn test_stack_sim_basic_class() { + let class = load_class("BasicClass.class"); + for method in &class.methods { + if let Some(code) = method.code() { + let is_static = method + .access_flags + .contains(classfile_parser::method_info::MethodAccessFlags::STATIC); + let cfg = cfg::build_cfg(code); + let simulated = + stack_sim::simulate_all_blocks(&cfg, &class.const_pool, code, is_static); + assert!(!simulated.is_empty(), "Should have simulated blocks"); + } + } +} + +#[test] +fn test_stack_sim_hello_world() { + let class = load_class("HelloWorld.class"); + let method = class.find_method("main").expect("main method should exist"); + let code = method.code().expect("main should have code"); + let cfg = cfg::build_cfg(code); + let simulated = stack_sim::simulate_all_blocks(&cfg, &class.const_pool, code, true); + + // main method should produce at least a method call statement 
(System.out.println) + let total_stmts: usize = simulated.iter().map(|b| b.statements.len()).sum(); + assert!( + total_stmts > 0, + "Should have at least one statement in main()" + ); +} + +// ---- Phase 3: Structuring tests ---- + +#[test] +fn test_structuring_basic() { + let class = load_class("BasicClass.class"); + for method in &class.methods { + if let Some(code) = method.code() { + let is_static = method + .access_flags + .contains(classfile_parser::method_info::MethodAccessFlags::STATIC); + let cfg = cfg::build_cfg(code); + let simulated = + stack_sim::simulate_all_blocks(&cfg, &class.const_pool, code, is_static); + let body = structuring::structure_method(&cfg, &simulated, &class.const_pool); + assert!(!body.statements.is_empty() || code.code.is_empty()); + } + } +} + +// ---- Phase 6: Full decompilation tests ---- + +#[test] +fn test_decompile_basic_class() { + let class = load_class("BasicClass.class"); + let result = decompile::decompile(&class).expect("decompilation should succeed"); + assert!( + result.contains("class"), + "Output should contain 'class' keyword" + ); + println!("--- BasicClass decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_hello_world() { + let class = load_class("HelloWorld.class"); + let result = decompile::decompile(&class).expect("decompilation should succeed"); + assert!( + result.contains("class HelloWorld"), + "Should contain class name" + ); + assert!(result.contains("main"), "Should contain main method"); + println!("--- HelloWorld decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_factorial() { + let class = load_class("Factorial.class"); + let result = decompile::decompile(&class).expect("decompilation should succeed"); + assert!( + result.contains("factorial"), + "Should contain factorial method" + ); + println!("--- Factorial decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_instructions() { + let class = load_class("Instructions.class"); + let result = 
decompile::decompile(&class).expect("decompilation should succeed"); + assert!(result.contains("class"), "Should produce output"); + println!("--- Instructions decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_record() { + let class = load_class("RecordExample.class"); + let result = decompile::decompile(&class).expect("decompilation should succeed"); + assert!(result.contains("record"), "Should contain 'record' keyword"); + println!("--- RecordExample decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_sealed() { + let class = load_class("SealedExample.class"); + let result = decompile::decompile(&class).expect("decompilation should succeed"); + assert!(result.contains("sealed"), "Should contain 'sealed' keyword"); + assert!( + result.contains("permits"), + "Should contain 'permits' keyword" + ); + println!("--- SealedExample decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_annotations() { + let class = load_class("Annotations.class"); + let result = decompile::decompile(&class).expect("decompilation should succeed"); + assert!(result.contains("@"), "Should contain annotation markers"); + println!("--- Annotations decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_with_options() { + let class = load_class("BasicClass.class"); + let options = DecompileOptions { + render_config: RenderConfig { + indent: " ".into(), + max_line_width: 80, + use_var: false, + include_synthetic: true, + }, + include_synthetic: true, + ..Default::default() + }; + let decompiler = Decompiler::new(options); + let result = decompiler + .decompile(&class) + .expect("decompilation should succeed"); + assert!(result.contains("class"), "Should produce output"); +} + +#[test] +fn test_decompile_inner_classes() { + let outer = load_class("NestExample.class"); + let inner = load_class("NestExample$Inner.class"); + let decompiler = Decompiler::new(DecompileOptions::default()); + let result = decompiler + 
.decompile_with_inner_classes(&outer, &[&inner]) + .expect("decompilation should succeed"); + assert!(result.contains("class"), "Should produce output"); + println!("--- NestExample + Inner decompilation ---\n{}", result); +} + +#[test] +fn test_decompile_single_method() { + let class = load_class("HelloWorld.class"); + let decompiler = Decompiler::new(DecompileOptions::default()); + let result = decompiler + .decompile_method(&class, "main") + .expect("method decompilation should succeed"); + assert!(result.contains("main"), "Should contain the method"); + println!("--- HelloWorld.main() ---\n{}", result); +} + +#[test] +fn test_decompile_all_test_classes() { + // Ensure we can decompile every test class without panicking + let classes = [ + "BasicClass.class", + "HelloWorld.class", + "Factorial.class", + "Instructions.class", + "Annotations.class", + "DeprecatedAnnotation.class", + "InnerClasses.class", + "LocalVariableTable.class", + "BootstrapMethods.class", + "RecordExample.class", + "SealedExample.class", + "SealedChild1.class", + "SealedChild2.class", + "NestExample.class", + "NestExample$Inner.class", + "UnicodeStrings.class", + ]; + + for class_name in &classes { + let class = load_class(class_name); + match decompile::decompile(&class) { + Ok(source) => { + dbg!(&source); + assert!( + !source.is_empty(), + "{} should produce non-empty output", + class_name + ); + } + Err(e) => { + panic!("{} failed to decompile: {}", class_name, e); + } + } + } +} diff --git a/tests/e2e_patch.rs b/tests/e2e_patch.rs new file mode 100644 index 0000000..9700671 --- /dev/null +++ b/tests/e2e_patch.rs @@ -0,0 +1,511 @@ +use std::fs; +use std::io::{Cursor, Read}; +use std::path::Path; +use std::process::Command; + +use binrw::BinWrite; +use binrw::prelude::*; +use classfile_parser::attribute_info::AttributeInfoVariant; +use classfile_parser::code_attribute::Instruction; +use classfile_parser::constant_info::{ConstantInfo, StringConstant, Utf8Constant}; +use 
classfile_parser::{ClassAccessFlags, ClassFile}; + +// --- Helpers --- + +fn java_available() -> bool { + Command::new("javac") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + && Command::new("java") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn compile_and_load( + test_name: &str, + java_src: &str, + class_name: &str, +) -> (std::path::PathBuf, std::path::PathBuf, ClassFile) { + let tmp_dir = std::env::temp_dir().join(format!("classfile_e2e_{}", test_name)); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let compile = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg(java_src) + .output() + .expect("failed to run javac"); + assert!( + compile.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&compile.stderr) + ); + + let class_path = tmp_dir.join(format!("{}.class", class_name)); + let mut class_bytes = Vec::new(); + std::fs::File::open(&class_path) + .expect("failed to open compiled class") + .read_to_end(&mut class_bytes) + .unwrap(); + let class_file = + ClassFile::read(&mut Cursor::new(&class_bytes)).expect("failed to parse class"); + + (tmp_dir, class_path, class_file) +} + +fn write_and_run( + tmp_dir: &Path, + class_path: &Path, + class_file: &ClassFile, + class_name: &str, +) -> String { + let mut out = Cursor::new(Vec::new()); + class_file.write(&mut out).expect("failed to write class"); + fs::write(class_path, out.into_inner()).expect("failed to write class file"); + + let run = Command::new("java") + .arg("-cp") + .arg(tmp_dir) + .arg(class_name) + .output() + .expect("failed to run java"); + assert!( + run.status.success(), + "java failed (exit {}): stderr={}", + run.status, + String::from_utf8_lossy(&run.stderr) + ); + String::from_utf8_lossy(&run.stdout).trim().to_string() +} + +/// Returns (method_index, code_attribute_index) for the named method. 
+fn find_code_attr(class_file: &ClassFile, method_name: &str) -> (usize, usize) { + let method_idx = class_file + .methods + .iter() + .position(|m| { + matches!( + &class_file.const_pool[(m.name_index - 1) as usize], + ConstantInfo::Utf8(u) if u.utf8_string == method_name + ) + }) + .unwrap_or_else(|| panic!("method '{}' not found", method_name)); + + let attr_idx = class_file.methods[method_idx] + .attributes + .iter() + .position(|a| matches!(a.info_parsed, Some(AttributeInfoVariant::Code(_)))) + .expect("no Code attribute found"); + + (method_idx, attr_idx) +} + +// --- Tests --- + +/// Test 1: Patch a float constant in the constant pool. +/// SimpleMath.floatMath divides 1.0 / 3.0 = 0.33333334. +/// Change 3.0 to 6.0, so the result becomes 1.0 / 6.0 = 0.16666667. +#[test] +fn test_e2e_patch_float_constant() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "float_const", + "java-assets/src/SimpleMath.java", + "SimpleMath", + ); + + let mut patched = false; + for entry in &mut class_file.const_pool { + if let ConstantInfo::Float(f) = entry { + if f.value == 3.0 { + f.value = 6.0; + patched = true; + } + } + } + assert!(patched, "could not find float 3.0 in constant pool"); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "SimpleMath"); + assert!( + output.contains("0.16666667"), + "expected 0.16666667 in output: {}", + output + ); +} + +/// Test 2: Patch an instruction operand via attribute reserialization. +/// SimpleMath.intMath has `int a = 10` which compiles to `bipush 10`. +/// Change to `bipush 5` so c = 5 + 20 = 25. 
+#[test] +fn test_e2e_patch_instruction_operand() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "instr_operand", + "java-assets/src/SimpleMath.java", + "SimpleMath", + ); + + let (mi, ai) = find_code_attr(&class_file, "intMath"); + + { + let code = match &mut class_file.methods[mi].attributes[ai].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + let mut found = false; + for instr in &mut code.code { + if *instr == Instruction::Bipush(10) { + *instr = Instruction::Bipush(5); + found = true; + break; + } + } + assert!(found, "could not find Bipush(10) in intMath"); + } + + class_file.methods[mi].attributes[ai] + .sync_from_parsed() + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "SimpleMath"); + assert!(output.contains("25"), "expected 25 in output: {}", output); +} + +/// Test 3: Replace an instruction opcode via attribute reserialization. +/// SimpleMath.intMath has `int c = a + b` which compiles to `iadd`. +/// Replace with `isub` so c = 10 - 20 = -10. 
+#[test] +fn test_e2e_replace_instruction_opcode() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "instr_opcode", + "java-assets/src/SimpleMath.java", + "SimpleMath", + ); + + let (mi, ai) = find_code_attr(&class_file, "intMath"); + + { + let code = match &mut class_file.methods[mi].attributes[ai].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + let mut found = false; + for instr in &mut code.code { + if *instr == Instruction::Iadd { + *instr = Instruction::Isub; + found = true; + break; + } + } + assert!(found, "could not find Iadd in intMath"); + } + + class_file.methods[mi].attributes[ai] + .sync_from_parsed() + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "SimpleMath"); + assert!(output.contains("-10"), "expected -10 in output: {}", output); +} + +/// Test 4: Add new constant pool entries and redirect an ldc instruction. +/// HelloWorld.main loads "Hello World!" via ldc. +/// Add "Injected!" to the pool and redirect the ldc to it. 
+#[test] +fn test_e2e_add_constant_and_redirect_ldc() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "redirect_ldc", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + // Add new Utf8 constant + let utf8_cp_index = (class_file.const_pool.len() + 1) as u16; + class_file.const_pool.push(ConstantInfo::Utf8(Utf8Constant { + utf8_string: String::from("Injected!"), + })); + + // Add new String constant referencing the Utf8 + let string_cp_index = (class_file.const_pool.len() + 1) as u16; + class_file + .const_pool + .push(ConstantInfo::String(StringConstant { + string_index: utf8_cp_index, + })); + + class_file.sync_counts(); + + assert!( + string_cp_index <= 255, + "string constant pool index {} exceeds u8 range", + string_cp_index + ); + + let (mi, ai) = find_code_attr(&class_file, "main"); + + { + let code = match &mut class_file.methods[mi].attributes[ai].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + let mut found = false; + for instr in &mut code.code { + if let Instruction::Ldc(_) = instr { + *instr = Instruction::Ldc(string_cp_index as u8); + found = true; + break; + } + } + assert!(found, "could not find Ldc in main"); + } + + class_file.methods[mi].attributes[ai] + .sync_from_parsed() + .unwrap(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "Injected!", + "expected 'Injected!' but got: {}", + output + ); +} + +/// Test 5: Patch class access flags. +/// Toggle FINAL on HelloWorld. It should still load and run +/// (FINAL just prevents subclassing). 
+#[test] +fn test_e2e_patch_access_flags() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "access_flags", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + assert!( + !class_file.access_flags.contains(ClassAccessFlags::FINAL), + "HelloWorld should not already be FINAL" + ); + + class_file.access_flags |= ClassAccessFlags::FINAL; + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "Hello World!", + "expected 'Hello World!' but got: {}", + output + ); + + // Re-parse and verify FINAL flag is set + let mut class_bytes = Vec::new(); + std::fs::File::open(&class_path) + .unwrap() + .read_to_end(&mut class_bytes) + .unwrap(); + let reparsed = ClassFile::read(&mut Cursor::new(&class_bytes)).expect("failed to re-parse"); + assert!( + reparsed.access_flags.contains(ClassAccessFlags::FINAL), + "FINAL flag should be set in re-parsed class" + ); +} + +/// Test 6: Remove a method and its call site. +/// SimpleMath has intMath, floatMath, and main. +/// Remove intMath from the methods array and replace its call in main +/// with nop instructions. Output should only contain "Float math:" line. +#[test] +fn test_e2e_remove_method() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "remove_method", + "java-assets/src/SimpleMath.java", + "SimpleMath", + ); + + let (main_mi, main_ai) = find_code_attr(&class_file, "main"); + + // Replace the first 4 instructions of main (the "Integer math:" println + intMath call) + // with nop instructions. Byte sizes: getstatic(3) + ldc(2) + invokevirtual(3) + invokestatic(3) = 11 nops. 
+ { + let code = match &mut class_file.methods[main_mi].attributes[main_ai].info_parsed { + Some(AttributeInfoVariant::Code(c)) => c, + _ => unreachable!(), + }; + + assert!( + matches!(&code.code[0], Instruction::Getstatic(_)), + "expected Getstatic as first instruction, got {:?}", + &code.code[0] + ); + assert!( + matches!(&code.code[1], Instruction::Ldc(_)), + "expected Ldc as second instruction, got {:?}", + &code.code[1] + ); + assert!( + matches!(&code.code[2], Instruction::Invokevirtual(_)), + "expected Invokevirtual as third instruction, got {:?}", + &code.code[2] + ); + assert!( + matches!(&code.code[3], Instruction::Invokestatic(_)), + "expected Invokestatic as fourth instruction, got {:?}", + &code.code[3] + ); + + // Replace first 4 instructions with 11 nops (matching total byte count: 3+2+3+3=11) + let nops: Vec<Instruction> = vec![Instruction::Nop; 11]; + code.code.splice(0..4, nops); + } + + class_file.methods[main_mi].attributes[main_ai] + .sync_from_parsed() + .unwrap(); + + // Remove intMath method + let int_math_idx = class_file + .methods + .iter() + .position(|m| { + matches!( + &class_file.const_pool[(m.name_index - 1) as usize], + ConstantInfo::Utf8(u) if u.utf8_string == "intMath" + ) + }) + .expect("intMath method not found"); + class_file.methods.remove(int_math_idx); + class_file.sync_counts(); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "SimpleMath"); + assert!( + !output.contains("Integer math:"), + "should not contain 'Integer math:', got: {}", + output + ); + assert!( + output.contains("Float math:"), + "expected 'Float math:' in output: {}", + output + ); + assert!( + output.contains("0.33333"), + "expected float result in output: {}", + output + ); +} + +#[test] +fn test_e2e_patch_string_contents() { + // Fail if javac/java aren't available + if Command::new("javac").arg("-version").output().is_err() { + eprintln!("skipping test_e2e_patch: javac not found"); + assert!(false); + return; + } + if 
Command::new("java").arg("-version").output().is_err() { + eprintln!("skipping test_e2e_patch: java not found"); + assert!(false); + return; + } + + let tmp_dir = std::env::temp_dir().join("classfile_e2e_patch_test"); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).expect("failed to create temp dir"); + eprintln!("Writing to {}", &tmp_dir.display()); + + let compile = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg("java-assets/src/HelloWorld.java") + .output() + .expect("failed to run javac"); + + assert!( + compile.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&compile.stderr) + ); + + let class_path = tmp_dir.join("HelloWorld.class"); + let mut class_bytes = Vec::new(); + std::fs::File::open(&class_path) + .expect("failed to open compiled class") + .read_to_end(&mut class_bytes) + .unwrap(); + + let mut class_file = + ClassFile::read(&mut Cursor::new(&class_bytes)).expect("failed to parse class"); + + let mut patched = false; + for entry in &mut class_file.const_pool { + if let ConstantInfo::Utf8(utf8) = entry { + if utf8.utf8_string == "Hello World!" { + utf8.utf8_string = String::from("Patched!"); + patched = true; + } + } + } + assert!(patched, "could not find 'Hello World!' in constant pool"); + + let mut out = Cursor::new(Vec::new()); + class_file + .write(&mut out) + .expect("failed to write patched class"); + fs::write(&class_path, out.into_inner()).expect("failed to write patched class file"); + + let run = Command::new("java") + .arg("-cp") + .arg(&tmp_dir) + .arg("HelloWorld") + .output() + .expect("failed to run java"); + + assert!( + run.status.success(), + "java failed (exit {}): {}", + run.status, + String::from_utf8_lossy(&run.stderr) + ); + + let stdout = String::from_utf8_lossy(&run.stdout); + assert_eq!( + stdout.trim(), + "Patched!", + "expected 'Patched!' 
but got: {:?}", + stdout + ); + + let _ = fs::remove_dir_all(&tmp_dir); +} diff --git a/tests/jar_patch.rs b/tests/jar_patch.rs new file mode 100644 index 0000000..e578e5f --- /dev/null +++ b/tests/jar_patch.rs @@ -0,0 +1,635 @@ +#![cfg(all(feature = "compile", feature = "jar-utils"))] + +use std::fs; +use std::io::Write; +use std::process::Command; + +use classfile_parser::compile::CompileOptions; +use classfile_parser::jar_patch::{self, JarPatchError}; +use classfile_parser::jar_utils::JarFile; +// Macros are at crate root via #[macro_export] +use classfile_parser::patch_jar; +use classfile_parser::patch_jar_class; +use classfile_parser::patch_jar_method; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn read_class_bytes(name: &str) -> Vec<u8> { + let path = format!("java-assets/compiled-classes/{name}"); + std::fs::read(&path).unwrap_or_else(|e| panic!("failed to read {path}: {e}")) +} + +fn build_jar(entries: &[(&str, &[u8])]) -> JarFile { + use std::io::Cursor; + use zip::CompressionMethod; + use zip::write::SimpleFileOptions; + + let mut buf = Cursor::new(Vec::new()); + { + let mut writer = zip::ZipWriter::new(&mut buf); + let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated); + for (name, data) in entries { + writer.start_file(*name, options).unwrap(); + writer.write_all(data).unwrap(); + } + writer.finish().unwrap(); + } + JarFile::from_bytes(&buf.into_inner()).unwrap() +} + +fn java_available() -> bool { + Command::new("javac") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + && Command::new("java") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn compile_java( + test_name: &str, + java_src: &str, + class_name: &str, +) -> (std::path::PathBuf, Vec<u8>) { + let tmp_dir = 
std::env::temp_dir().join(format!("classfile_jar_patch_{test_name}")); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let javac = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg(java_src) + .output() + .expect("failed to run javac"); + assert!( + javac.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&javac.stderr) + ); + + let class_path = tmp_dir.join(format!("{class_name}.class")); + let class_bytes = fs::read(&class_path).unwrap(); + (tmp_dir, class_bytes) +} + +fn run_jar(jar: &JarFile, class_name: &str, test_name: &str) -> String { + let tmp_dir = std::env::temp_dir().join(format!("classfile_jar_patch_run_{test_name}")); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let jar_path = tmp_dir.join("test.jar"); + jar.save(&jar_path).unwrap(); + + let run = Command::new("java") + .arg("-cp") + .arg(&jar_path) + .arg(class_name) + .output() + .expect("failed to run java"); + let _ = fs::remove_dir_all(&tmp_dir); + assert!( + run.status.success(), + "java failed (exit {}): stderr={}", + run.status, + String::from_utf8_lossy(&run.stderr) + ); + String::from_utf8_lossy(&run.stdout).trim().to_string() +} + +// --------------------------------------------------------------------------- +// Error case tests +// --------------------------------------------------------------------------- + +#[test] +fn test_patch_jar_method_class_not_found() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + let result = jar_patch::patch_jar_method( + &mut jar, + "DoesNotExist.class", + "main", + r#"{ return; }"#, + &CompileOptions::default(), + ); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), JarPatchError::Jar(_))); +} + +#[test] +fn test_patch_jar_method_method_not_found() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", 
&class_bytes)]); + + let result = jar_patch::patch_jar_method( + &mut jar, + "HelloWorld.class", + "doesNotExist", + r#"{ return; }"#, + &CompileOptions::default(), + ); + assert!(result.is_err()); + match result.unwrap_err() { + JarPatchError::Compile(classfile_parser::compile::CompileError::MethodNotFound { + name, + }) => { + assert_eq!(name, "doesNotExist"); + } + other => panic!("expected MethodNotFound, got: {other:?}"), + } +} + +// --------------------------------------------------------------------------- +// Function tests +// --------------------------------------------------------------------------- + +#[test] +fn test_patch_jar_method_basic() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + // Patch should succeed — we're replacing main with a simple body + jar_patch::patch_jar_method( + &mut jar, + "HelloWorld.class", + "main", + r#"{ System.out.println("patched"); }"#, + &CompileOptions::default(), + ) + .unwrap(); + + // The class entry should be updated + let updated_bytes = jar.get_entry("HelloWorld.class").unwrap(); + assert_ne!( + updated_bytes, + &class_bytes[..], + "class bytes should differ after patching" + ); +} + +#[test] +fn test_patch_jar_class_multiple_methods() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + jar_patch::patch_jar_class( + &mut jar, + "HelloWorld.class", + &[("main", r#"{ System.out.println("one"); }"#)], + &CompileOptions::default(), + ) + .unwrap(); + + let updated_bytes = jar.get_entry("HelloWorld.class").unwrap(); + assert_ne!(updated_bytes, &class_bytes[..]); +} + +// --------------------------------------------------------------------------- +// Macro syntax tests +// --------------------------------------------------------------------------- + +#[test] +fn test_macro_patch_jar_method() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut 
jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + patch_jar_method!(jar, "HelloWorld.class", "main", r#"{ return; }"#).unwrap(); + + let updated = jar.get_entry("HelloWorld.class").unwrap(); + assert_ne!(updated, &class_bytes[..]); +} + +#[test] +fn test_macro_patch_jar_method_no_verify() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + patch_jar_method!(jar, "HelloWorld.class", "main", r#"{ return; }"#, no_verify).unwrap(); +} + +#[test] +fn test_macro_patch_jar_class() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + patch_jar_class!(jar, "HelloWorld.class", { + "main" => r#"{ return; }"#, + }) + .unwrap(); +} + +#[test] +fn test_macro_patch_jar_class_no_verify() { + let class_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &class_bytes)]); + + patch_jar_class!(jar, "HelloWorld.class", no_verify, { + "main" => r#"{ return; }"#, + }) + .unwrap(); +} + +#[test] +fn test_macro_patch_jar_multi_class() { + let hello_bytes = read_class_bytes("HelloWorld.class"); + let basic_bytes = read_class_bytes("BasicClass.class"); + let mut jar = build_jar(&[ + ("HelloWorld.class", &hello_bytes), + ("BasicClass.class", &basic_bytes), + ]); + + patch_jar!(jar, { + "HelloWorld.class" => { + "main" => r#"{ return; }"#, + }, + }) + .unwrap(); +} + +#[test] +fn test_macro_patch_jar_multi_class_no_verify() { + let hello_bytes = read_class_bytes("HelloWorld.class"); + let mut jar = build_jar(&[("HelloWorld.class", &hello_bytes)]); + + patch_jar!(jar, no_verify, { + "HelloWorld.class" => { + "main" => r#"{ return; }"#, + }, + }) + .unwrap(); +} + +// --------------------------------------------------------------------------- +// E2E tests (require javac + java) +// --------------------------------------------------------------------------- + +#[test] +fn 
test_e2e_patch_jar_hello_world() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "e2e_jar_hello", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + System.out.println("jar-patched!"); + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "e2e_jar_hello"); + assert_eq!(output, "jar-patched!"); +} + +#[test] +fn test_e2e_patch_jar_multi_method() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "e2e_jar_multi", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_class!(jar, "HelloWorld.class", { + "main" => r#"{ + System.out.println("multi-patched"); + }"#, + }) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "e2e_jar_multi"); + assert_eq!(output, "multi-patched"); +} + +#[test] +fn test_e2e_patch_jar_macro() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "e2e_jar_macro", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar!(jar, { + "HelloWorld.class" => { + "main" => r#"{ + System.out.println("full-macro"); + }"#, + }, + }) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "e2e_jar_macro"); + assert_eq!(output, "full-macro"); +} + +#[test] +fn test_e2e_patch_jar_save_and_load() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "e2e_jar_save", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + 
jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + System.out.println("saved-and-loaded"); + }"# + ) + .unwrap(); + + // Save to disk, re-open, verify contents survived round-trip + let tmp_dir = std::env::temp_dir().join("classfile_jar_patch_save_test"); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + let jar_path = tmp_dir.join("test.jar"); + + jar.save(&jar_path).unwrap(); + let reloaded = JarFile::open(&jar_path).unwrap(); + + let output = run_jar(&reloaded, "HelloWorld", "e2e_jar_save"); + assert_eq!(output, "saved-and-loaded"); + + let _ = fs::remove_dir_all(&tmp_dir); +} + +// --------------------------------------------------------------------------- +// Stress tests: JAR patching with complex method bodies +// --------------------------------------------------------------------------- + +/// Jar patch with array operations, for-each, and string concat +#[test] +fn test_stress_jar_patch_complex_body() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "stress_jar_complex", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + int[] arr = new int[10]; + for (int i = 0; i < 10; i++) { + arr[i] = i * i; + } + int sum = 0; + for (int x : arr) { + sum = sum + x; + } + System.out.println("sum=" + sum); + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "stress_jar_complex"); + assert_eq!(output, "sum=285"); +} + +/// Jar patch with try-catch-finally +#[test] +fn test_stress_jar_patch_try_catch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "stress_jar_trycatch", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + 
let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + try { + System.out.println("before"); + throw new RuntimeException("test"); + } catch (RuntimeException e) { + System.out.println("caught"); + } finally { + System.out.println("finally"); + } + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "stress_jar_trycatch"); + assert_eq!(output, "before\ncaught\nfinally"); +} + +/// Jar patch with switch expression and var keyword +#[test] +fn test_stress_jar_patch_modern_java() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "stress_jar_modern", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + var x = 3; + var result = switch (x) { + case 1 -> "one"; + case 2 -> "two"; + case 3 -> "three"; + default -> "other"; + }; + System.out.println(result); + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "stress_jar_modern"); + assert_eq!(output, "three"); +} + +/// Jar patch with bubble sort algorithm +#[test] +fn test_stress_jar_patch_bubble_sort() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "stress_jar_sort", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + int[] arr = new int[5]; + arr[0] = 42; + arr[1] = 17; + arr[2] = 99; + arr[3] = 3; + arr[4] = 55; + for (int i = 0; i < 5; i++) { + for (int j = 0; j < 4 - i; j++) { + if (arr[j] > arr[j + 1]) { + int temp = arr[j]; + arr[j] = arr[j + 1]; + arr[j + 1] = temp; + } + } + } + for (int x : arr) { + 
System.out.println(x); + } + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "stress_jar_sort"); + assert_eq!(output, "3\n17\n42\n55\n99"); +} + +/// Jar patch with synchronized + exception handling +#[test] +fn test_stress_jar_patch_sync_trycatch() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "stress_jar_sync_try", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + Object lock = new Object(); + synchronized (lock) { + try { + System.out.println("in-sync-try"); + } catch (Exception e) { + System.out.println("error"); + } finally { + System.out.println("in-sync-finally"); + } + } + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "stress_jar_sync_try"); + assert_eq!(output, "in-sync-try\nin-sync-finally"); +} + +/// Jar patch with multi-dimensional arrays +#[test] +fn test_stress_jar_patch_multi_dim() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + let (_tmp, class_bytes) = compile_java( + "stress_jar_multidim", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + let mut jar = JarFile::new(); + jar.set_entry("HelloWorld.class", class_bytes); + + patch_jar_method!( + jar, + "HelloWorld.class", + "main", + r#"{ + int[][] grid = new int[3][3]; + int v = 1; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + grid[i][j] = v; + v++; + } + } + int sum = grid[0][0] + grid[1][1] + grid[2][2]; + System.out.println(sum); + }"# + ) + .unwrap(); + + let output = run_jar(&jar, "HelloWorld", "stress_jar_multidim"); + // Trace: 1 + 5 + 9 = 15 + assert_eq!(output, "15"); +} diff --git a/tests/jar_utils.rs b/tests/jar_utils.rs new file mode 100644 index 0000000..fc0c3b9 --- /dev/null +++ b/tests/jar_utils.rs @@ -0,0 +1,295 @@ 
+#![cfg(feature = "jar-utils")] + +extern crate classfile_parser; + +use std::io::{Cursor, Write}; + +use classfile_parser::ClassFile; +use classfile_parser::jar_utils::{JarFile, JarManifest}; + +use binrw::BinRead; +use zip::CompressionMethod; +use zip::write::SimpleFileOptions; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Read a .class file from disk as raw bytes. +fn read_class_bytes(name: &str) -> Vec<u8> { + let path = format!("java-assets/compiled-classes/{name}"); + std::fs::read(&path).unwrap_or_else(|e| panic!("failed to read {path}: {e}")) +} + +/// Build a minimal JAR in-memory with the given entries. +fn build_jar(entries: &[(&str, &[u8])]) -> Vec<u8> { + let mut buf = Cursor::new(Vec::new()); + { + let mut writer = zip::ZipWriter::new(&mut buf); + let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated); + for (name, data) in entries { + writer.start_file(*name, options).unwrap(); + writer.write_all(data).unwrap(); + } + writer.finish().unwrap(); + } + buf.into_inner() +} + +// =========================================================================== +// Manifest tests +// =========================================================================== + +#[test] +fn test_manifest_parse() { + let manifest_data = b"Manifest-Version: 1.0\r\nCreated-By: test\r\n\r\nName: com/example/Foo.class\r\nSHA-256-Digest: abc123\r\n"; + let manifest = JarManifest::parse(manifest_data).unwrap(); + + assert_eq!(manifest.main_attr("Manifest-Version"), Some("1.0")); + assert_eq!(manifest.main_attr("Created-By"), Some("test")); + + // Case-insensitive lookup + assert_eq!(manifest.main_attr("manifest-version"), Some("1.0")); + assert_eq!(manifest.main_attr("CREATED-BY"), Some("test")); + + let section = manifest.entry_section("com/example/Foo.class").unwrap(); + assert_eq!(section.get("SHA-256-Digest"), 
Some("abc123")); +} + +#[test] +fn test_manifest_continuation_lines() { + // A long value that wraps past 72 bytes + let long_value = "a]".repeat(40); // 80 chars + + // Manually wrap: first line up to 72 bytes, then continuation + let full_line = format!("Long-Key: {}", long_value); + let first_72 = &full_line[..72]; + let rest = &full_line[72..]; + let wrapped = format!("Manifest-Version: 1.0\r\n{}\r\n {}\r\n", first_72, rest); + + let manifest = JarManifest::parse(wrapped.as_bytes()).unwrap(); + assert_eq!(manifest.main_attr("Long-Key"), Some(long_value.as_str())); +} + +#[test] +fn test_manifest_round_trip() { + let original = b"Manifest-Version: 1.0\r\nCreated-By: test\r\n\r\nName: com/example/A.class\r\nDigest: aaa\r\n\r\nName: com/example/B.class\r\nDigest: bbb\r\n"; + let manifest = JarManifest::parse(original).unwrap(); + let serialized = manifest.to_bytes(); + let reparsed = JarManifest::parse(&serialized).unwrap(); + + assert_eq!(manifest.main_attributes, reparsed.main_attributes); + assert_eq!(manifest.entries.len(), reparsed.entries.len()); + for (name, attrs) in &manifest.entries { + assert_eq!(reparsed.entry_section(name).unwrap(), attrs); + } +} + +#[test] +fn test_default_manifest() { + let m = JarManifest::default_manifest(); + assert_eq!(m.main_attr("Manifest-Version"), Some("1.0")); + assert!(m.entries.is_empty()); +} + +// =========================================================================== +// JarFile I/O tests +// =========================================================================== + +#[test] +fn test_read_and_list_entries() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_jar(&[ + ("com/example/BasicClass.class", &class_bytes), + ("META-INF/MANIFEST.MF", b"Manifest-Version: 1.0\r\n"), + ("readme.txt", b"hello"), + ]); + + let jar = JarFile::from_bytes(&jar_bytes).unwrap(); + + let names: Vec<&str> = jar.entry_names().collect(); + assert_eq!(names.len(), 3); + 
assert!(names.contains(&"com/example/BasicClass.class")); + assert!(names.contains(&"META-INF/MANIFEST.MF")); + assert!(names.contains(&"readme.txt")); + + let class_names: Vec<&str> = jar.class_names().collect(); + assert_eq!(class_names.len(), 1); + assert_eq!(class_names[0], "com/example/BasicClass.class"); +} + +#[test] +fn test_round_trip_jar() { + let class_bytes = read_class_bytes("BasicClass.class"); + let manifest = b"Manifest-Version: 1.0\r\n"; + let jar_bytes = build_jar(&[ + ("com/example/BasicClass.class", &class_bytes), + ("META-INF/MANIFEST.MF", manifest), + ]); + + let jar = JarFile::from_bytes(&jar_bytes).unwrap(); + let rewritten = jar.to_bytes().unwrap(); + let jar2 = JarFile::from_bytes(&rewritten).unwrap(); + + // Same entries + let names1: Vec<&str> = jar.entry_names().collect(); + let names2: Vec<&str> = jar2.entry_names().collect(); + assert_eq!(names1, names2); + + // Same content + for name in &names1 { + assert_eq!(jar.get_entry(name), jar2.get_entry(name)); + } +} + +#[test] +fn test_add_remove_entries() { + let jar_bytes = build_jar(&[("a.txt", b"aaa")]); + let mut jar = JarFile::from_bytes(&jar_bytes).unwrap(); + + assert!(jar.contains_entry("a.txt")); + assert!(!jar.contains_entry("b.txt")); + + jar.set_entry("b.txt", b"bbb".to_vec()); + assert!(jar.contains_entry("b.txt")); + assert_eq!(jar.get_entry("b.txt"), Some(b"bbb".as_slice())); + + let removed = jar.remove_entry("a.txt"); + assert_eq!(removed, Some(b"aaa".to_vec())); + assert!(!jar.contains_entry("a.txt")); +} + +#[test] +fn test_open_and_save() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_jar(&[("com/example/BasicClass.class", &class_bytes)]); + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.jar"); + + // Write manually then read with open() + std::fs::write(&path, &jar_bytes).unwrap(); + let jar = JarFile::open(&path).unwrap(); + assert!(jar.contains_entry("com/example/BasicClass.class")); + + // Save and 
re-open + let path2 = dir.path().join("test2.jar"); + jar.save(&path2).unwrap(); + let jar2 = JarFile::open(&path2).unwrap(); + assert_eq!( + jar.get_entry("com/example/BasicClass.class"), + jar2.get_entry("com/example/BasicClass.class"), + ); +} + +// =========================================================================== +// ClassFile integration tests +// =========================================================================== + +#[test] +fn test_parse_class_from_jar() { + let class_bytes = read_class_bytes("BasicClass.class"); + + // Parse directly for reference + let direct = ClassFile::read(&mut Cursor::new(&class_bytes)).unwrap(); + + // Parse from JAR + let jar_bytes = build_jar(&[("BasicClass.class", &class_bytes)]); + let jar = JarFile::from_bytes(&jar_bytes).unwrap(); + let from_jar = jar.parse_class("BasicClass.class").unwrap(); + + assert_eq!(direct.major_version, from_jar.major_version); + assert_eq!(direct.minor_version, from_jar.minor_version); + assert_eq!(direct.const_pool_size, from_jar.const_pool_size); + assert_eq!(direct.methods_count, from_jar.methods_count); + assert_eq!(direct.fields_count, from_jar.fields_count); +} + +#[test] +fn test_set_class_round_trip() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_jar(&[("BasicClass.class", &class_bytes)]); + let mut jar = JarFile::from_bytes(&jar_bytes).unwrap(); + + let class_file = jar.parse_class("BasicClass.class").unwrap(); + jar.set_class("BasicClass.class", &class_file).unwrap(); + + let reparsed = jar.parse_class("BasicClass.class").unwrap(); + assert_eq!(class_file.major_version, reparsed.major_version); + assert_eq!(class_file.const_pool_size, reparsed.const_pool_size); + assert_eq!(class_file.methods_count, reparsed.methods_count); +} + +#[test] +fn test_patch_class_in_jar() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_jar(&[("BasicClass.class", &class_bytes)]); + let mut jar = 
JarFile::from_bytes(&jar_bytes).unwrap(); + + let mut class_file = jar.parse_class("BasicClass.class").unwrap(); + + // Add a UTF-8 constant as a marker + let marker = "jar_utils_test_marker"; + class_file.add_utf8(marker); + class_file.sync_counts(); + + jar.set_class("BasicClass.class", &class_file).unwrap(); + + // Re-parse and verify the marker exists + let reparsed = jar.parse_class("BasicClass.class").unwrap(); + assert!( + reparsed.find_utf8_index(marker).is_some(), + "marker constant should be present after patching" + ); +} + +#[test] +fn test_parse_all_classes() { + let basic_bytes = read_class_bytes("BasicClass.class"); + let factorial_bytes = read_class_bytes("Factorial.class"); + let jar_bytes = build_jar(&[ + ("com/example/BasicClass.class", &basic_bytes), + ("com/example/Factorial.class", &factorial_bytes), + ("readme.txt", b"not a class"), + ]); + + let jar = JarFile::from_bytes(&jar_bytes).unwrap(); + let results = jar.parse_all_classes(); + + assert_eq!(results.len(), 2); + for (name, result) in &results { + assert!(name.ends_with(".class")); + assert!(result.is_ok(), "failed to parse {name}"); + } +} + +// =========================================================================== +// Manifest-in-JAR tests +// =========================================================================== + +#[test] +fn test_manifest_in_jar() { + let manifest_data = b"Manifest-Version: 1.0\r\nMain-Class: com.example.Main\r\n"; + let jar_bytes = build_jar(&[ + ("META-INF/MANIFEST.MF", manifest_data), + ("readme.txt", b"hi"), + ]); + + let mut jar = JarFile::from_bytes(&jar_bytes).unwrap(); + + // Read manifest + let manifest = jar.manifest().unwrap().unwrap(); + assert_eq!(manifest.main_attr("Main-Class"), Some("com.example.Main")); + + // Modify and store + let mut updated = manifest.clone(); + updated.set_main_attr("Main-Class", "com.example.Other"); + jar.set_manifest(&updated); + + // Re-read + let re_manifest = jar.manifest().unwrap().unwrap(); + assert_eq!( + 
re_manifest.main_attr("Main-Class"), + Some("com.example.Other") + ); +} diff --git a/tests/module_attribute.rs b/tests/module_attribute.rs new file mode 100644 index 0000000..c84c872 --- /dev/null +++ b/tests/module_attribute.rs @@ -0,0 +1,106 @@ +use std::io::Cursor; + +use binrw::prelude::*; +use classfile_parser::ClassFile; +use classfile_parser::attribute_info::AttributeInfoVariant; +use classfile_parser::constant_info::ConstantInfo; + +fn lookup_string(c: &ClassFile, index: u16) -> Option<String> { + let con = &c.const_pool[(index - 1) as usize]; + match con { + ConstantInfo::Utf8(utf8) => Some(utf8.utf8_string.to_string()), + ConstantInfo::Module(m) => lookup_string(c, m.name_index), + ConstantInfo::Package(p) => lookup_string(c, p.name_index), + _ => None, + } +} + +#[test] +fn module_info() { + let class_bytes = include_bytes!("../java-assets/compiled-classes/module-info.class"); + let class = ClassFile::read(&mut Cursor::new(class_bytes.as_slice())) + .expect("failed to parse module-info.class"); + + // module-info.class should have ACC_MODULE set + assert!( + class + .access_flags + .contains(classfile_parser::ClassAccessFlags::MODULE), + "expected ACC_MODULE flag" + ); + + // Find the Module attribute in the class-level attributes + let module_attr = class + .attributes + .iter() + .find_map(|a| match &a.info_parsed { + Some(AttributeInfoVariant::Module(m)) => Some(m), + _ => None, + }) + .expect("Module attribute not found"); + + // Verify structural fields + assert_eq!(module_attr.module_flags, 0); + assert_eq!(module_attr.module_version_index, 0); + + assert_eq!(module_attr.requires_count, 1); + assert_eq!(module_attr.requires.len(), 1); + let req = &module_attr.requires[0]; + assert_eq!(req.requires_flags, 32768); // ACC_MANDATED + + assert_eq!(module_attr.exports_count, 1); + assert_eq!(module_attr.exports.len(), 1); + let exp = &module_attr.exports[0]; + assert_eq!(exp.exports_flags, 0); + assert_eq!(exp.exports_to_count, 0); + 
assert_eq!(exp.exports_to_index.len(), 0); + + assert_eq!(module_attr.opens_count, 0); + assert_eq!(module_attr.opens.len(), 0); + assert_eq!(module_attr.uses_count, 0); + assert_eq!(module_attr.uses.len(), 0); + assert_eq!(module_attr.provides_count, 0); + assert_eq!(module_attr.provides.len(), 0); + + // Verify string lookups via constant pool + assert_eq!( + lookup_string(&class, module_attr.module_name_index) + .unwrap() + .as_str(), + "my.module" + ); + assert_eq!( + lookup_string(&class, req.requires_index).unwrap().as_str(), + "java.base" + ); + assert_eq!( + lookup_string(&class, exp.exports_index).unwrap().as_str(), + "com/some" + ); +} + +#[test] +fn module_info_round_trip() { + let original_bytes = + include_bytes!("../java-assets/compiled-classes/module-info.class").to_vec(); + let parsed = ClassFile::read(&mut Cursor::new(original_bytes.as_slice())) + .expect("failed to parse module-info.class"); + + let mut written_bytes = Cursor::new(Vec::new()); + parsed + .write(&mut written_bytes) + .expect("failed to write module-info.class"); + let written_bytes = written_bytes.into_inner(); + + assert_eq!( + original_bytes.len(), + written_bytes.len(), + "written class file has different length: original={}, written={}", + original_bytes.len(), + written_bytes.len() + ); + assert_eq!( + original_bytes, written_bytes, + "written class file bytes differ from original" + ); +} diff --git a/tests/new_attributes.rs b/tests/new_attributes.rs new file mode 100644 index 0000000..a001fe6 --- /dev/null +++ b/tests/new_attributes.rs @@ -0,0 +1,277 @@ +extern crate classfile_parser; + +use std::fs::File; +use std::io::Cursor; +use std::io::prelude::*; + +use binrw::BinWrite; +use binrw::prelude::*; +use classfile_parser::ClassFile; +use classfile_parser::attribute_info::AttributeInfoVariant; + +fn load_class(path: &str) -> ClassFile { + let mut contents: Vec<u8> = Vec::new(); + let mut f = File::open(path).unwrap(); + f.read_to_end(&mut contents).unwrap(); + 
ClassFile::read(&mut Cursor::new(contents)).expect("failed to parse class file") +} + +fn find_attr<'a>( + attrs: &'a [classfile_parser::attribute_info::AttributeInfo], + name: &str, + class: &ClassFile, +) -> Option<&'a AttributeInfoVariant> { + for attr in attrs { + if let Some(ref parsed) = attr.info_parsed { + let attr_name = match &class.const_pool[(attr.attribute_name_index - 1) as usize] { + classfile_parser::constant_info::ConstantInfo::Utf8(u) => u.utf8_string.as_str(), + _ => continue, + }; + if attr_name == name { + return Some(parsed); + } + } + } + None +} + +fn round_trip(path: &str) { + let mut original_bytes: Vec<u8> = Vec::new(); + let mut f = File::open(path).unwrap(); + f.read_to_end(&mut original_bytes).unwrap(); + + let parsed = ClassFile::read(&mut Cursor::new(&original_bytes)).expect("failed to parse"); + + let mut written_bytes = Cursor::new(Vec::new()); + parsed.write(&mut written_bytes).expect("failed to write"); + let written_bytes = written_bytes.into_inner(); + + assert_eq!( + original_bytes.len(), + written_bytes.len(), + "round-trip length mismatch for {}: original={}, written={}", + path, + original_bytes.len(), + written_bytes.len() + ); + assert_eq!( + original_bytes, written_bytes, + "round-trip bytes differ for {}", + path + ); +} + +// --- NestMembers (on the outer class) --- + +#[test] +fn nest_members() { + let c = load_class("java-assets/compiled-classes/NestExample.class"); + let attr = find_attr(&c.attributes, "NestMembers", &c) + .expect("NestMembers attribute not found on NestExample"); + + match attr { + AttributeInfoVariant::NestMembers(nm) => { + assert_eq!(nm.number_of_classes, 1); + assert_eq!(nm.classes.len(), 1); + // The single nest member should point to NestExample$Inner + } + other => panic!("Expected NestMembers, got {:?}", other), + } +} + +#[test] +fn nest_members_round_trip() { + round_trip("java-assets/compiled-classes/NestExample.class"); +} + +// --- NestHost (on the inner class) --- + +#[test] +fn 
nest_host() { + let c = load_class("java-assets/compiled-classes/NestExample$Inner.class"); + let attr = find_attr(&c.attributes, "NestHost", &c) + .expect("NestHost attribute not found on NestExample$Inner"); + + match attr { + AttributeInfoVariant::NestHost(nh) => { + // host_class_index should point to NestExample class constant + assert!(nh.host_class_index > 0); + } + other => panic!("Expected NestHost, got {:?}", other), + } +} + +#[test] +fn nest_host_round_trip() { + round_trip("java-assets/compiled-classes/NestExample$Inner.class"); +} + +// --- Record --- + +#[test] +fn record_attribute() { + let c = load_class("java-assets/compiled-classes/RecordExample.class"); + let attr = find_attr(&c.attributes, "Record", &c) + .expect("Record attribute not found on RecordExample"); + + match attr { + AttributeInfoVariant::Record(rec) => { + assert_eq!(rec.components_count, 2); + assert_eq!(rec.components.len(), 2); + + // First component: int x + let comp0 = &rec.components[0]; + let name0 = match &c.const_pool[(comp0.name_index - 1) as usize] { + classfile_parser::constant_info::ConstantInfo::Utf8(u) => u.utf8_string.as_str(), + _ => panic!("expected Utf8"), + }; + assert_eq!(name0, "x"); + + // Second component: String name + let comp1 = &rec.components[1]; + let name1 = match &c.const_pool[(comp1.name_index - 1) as usize] { + classfile_parser::constant_info::ConstantInfo::Utf8(u) => u.utf8_string.as_str(), + _ => panic!("expected Utf8"), + }; + assert_eq!(name1, "name"); + + // Record components may have Signature sub-attributes that should be interpreted + for comp in &rec.components { + for attr in &comp.attributes { + assert!( + attr.info_parsed.is_some(), + "Record component sub-attribute should have info_parsed populated" + ); + } + } + } + other => panic!("Expected Record, got {:?}", other), + } +} + +#[test] +fn record_round_trip() { + round_trip("java-assets/compiled-classes/RecordExample.class"); +} + +// --- PermittedSubclasses --- + +#[test] +fn 
permitted_subclasses() { + let c = load_class("java-assets/compiled-classes/SealedExample.class"); + let attr = find_attr(&c.attributes, "PermittedSubclasses", &c) + .expect("PermittedSubclasses attribute not found on SealedExample"); + + match attr { + AttributeInfoVariant::PermittedSubclasses(ps) => { + assert_eq!(ps.number_of_classes, 2); + assert_eq!(ps.classes.len(), 2); + } + other => panic!("Expected PermittedSubclasses, got {:?}", other), + } +} + +#[test] +fn permitted_subclasses_round_trip() { + round_trip("java-assets/compiled-classes/SealedExample.class"); +} + +// --- ModulePackages (byte-level test) --- + +#[test] +fn module_packages_parse() { + use classfile_parser::attribute_info::ModulePackagesAttribute; + + // ModulePackages { package_count: 2, package_index: [5, 10] } + let bytes: Vec<u8> = vec![ + 0x00, 0x02, // package_count = 2 + 0x00, 0x05, // package_index[0] = 5 + 0x00, 0x0A, // package_index[1] = 10 + ]; + + let parsed = ModulePackagesAttribute::read(&mut Cursor::new(&bytes)).expect("failed to parse"); + assert_eq!(parsed.package_count, 2); + assert_eq!(parsed.package_index, vec![5, 10]); + + // Round-trip + let mut written = Cursor::new(Vec::new()); + parsed.write(&mut written).expect("failed to write"); + assert_eq!(written.into_inner(), bytes); +} + +// --- ModuleMainClass (byte-level test) --- + +#[test] +fn module_main_class_parse() { + use classfile_parser::attribute_info::ModuleMainClassAttribute; + + // ModuleMainClass { main_class_index: 42 } + let bytes: Vec<u8> = vec![0x00, 0x2A]; // 42 + + let parsed = ModuleMainClassAttribute::read(&mut Cursor::new(&bytes)).expect("failed to parse"); + assert_eq!(parsed.main_class_index, 42); + + // Round-trip + let mut written = Cursor::new(Vec::new()); + parsed.write(&mut written).expect("failed to write"); + assert_eq!(written.into_inner(), bytes); +} + +// --- Code sub-attribute interpretation --- + +#[test] +fn code_sub_attributes_are_interpreted() { + // Verify that sub-attributes inside 
CodeAttribute now have info_parsed populated + let c = load_class("java-assets/compiled-classes/BasicClass.class"); + + for method in &c.methods { + for attr in &method.attributes { + if let Some(AttributeInfoVariant::Code(ref code)) = attr.info_parsed { + for sub_attr in &code.attributes { + assert!( + sub_attr.info_parsed.is_some(), + "Code sub-attribute (name_index={}) should have info_parsed populated", + sub_attr.attribute_name_index + ); + // Verify it's not Unknown + if let Some(AttributeInfoVariant::Unknown(name)) = &sub_attr.info_parsed { + panic!("Code sub-attribute '{}' parsed as Unknown", name); + } + } + } + } + } +} + +#[test] +fn code_sub_attribute_line_number_table() { + // Verify LineNumberTable inside Code is now directly accessible via info_parsed + let c = load_class("java-assets/compiled-classes/BasicClass.class"); + + let mut found_line_number_table = false; + for method in &c.methods { + for attr in &method.attributes { + if let Some(AttributeInfoVariant::Code(ref code)) = attr.info_parsed { + for sub_attr in &code.attributes { + if let Some(AttributeInfoVariant::LineNumberTable(ref lnt)) = + sub_attr.info_parsed + { + found_line_number_table = true; + assert!( + lnt.line_number_table_length > 0, + "LineNumberTable should have entries" + ); + assert_eq!( + lnt.line_number_table.len(), + lnt.line_number_table_length as usize + ); + } + } + } + } + } + assert!( + found_line_number_table, + "Should have found at least one LineNumberTable sub-attribute" + ); +} diff --git a/tests/patching_helpers.rs b/tests/patching_helpers.rs new file mode 100644 index 0000000..699346a --- /dev/null +++ b/tests/patching_helpers.rs @@ -0,0 +1,454 @@ +use std::fs; +use std::io::{Cursor, Read}; +use std::path::Path; +use std::process::Command; + +use binrw::BinWrite; +use binrw::prelude::*; +use classfile_parser::ClassFile; +use classfile_parser::code_attribute::Instruction; +use classfile_parser::constant_info::ConstantInfo; + +// --- Helpers --- + +fn 
load_basic_class() -> ClassFile { + let mut contents: Vec<u8> = Vec::new(); + std::fs::File::open("java-assets/compiled-classes/BasicClass.class") + .unwrap() + .read_to_end(&mut contents) + .unwrap(); + ClassFile::read(&mut Cursor::new(&contents)).expect("failed to parse BasicClass") +} + +fn java_available() -> bool { + Command::new("javac") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + && Command::new("java") + .arg("-version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn compile_and_load( + test_name: &str, + java_src: &str, + class_name: &str, +) -> (std::path::PathBuf, std::path::PathBuf, ClassFile) { + let tmp_dir = std::env::temp_dir().join(format!("classfile_helpers_{}", test_name)); + let _ = fs::remove_dir_all(&tmp_dir); + fs::create_dir_all(&tmp_dir).unwrap(); + + let compile = Command::new("javac") + .arg("-d") + .arg(&tmp_dir) + .arg(java_src) + .output() + .expect("failed to run javac"); + assert!( + compile.status.success(), + "javac failed: {}", + String::from_utf8_lossy(&compile.stderr) + ); + + let class_path = tmp_dir.join(format!("{}.class", class_name)); + let mut class_bytes = Vec::new(); + std::fs::File::open(&class_path) + .expect("failed to open compiled class") + .read_to_end(&mut class_bytes) + .unwrap(); + let class_file = + ClassFile::read(&mut Cursor::new(&class_bytes)).expect("failed to parse class"); + + (tmp_dir, class_path, class_file) +} + +fn write_and_run( + tmp_dir: &Path, + class_path: &Path, + class_file: &ClassFile, + class_name: &str, +) -> String { + let mut out = Cursor::new(Vec::new()); + class_file.write(&mut out).expect("failed to write class"); + fs::write(class_path, out.into_inner()).expect("failed to write class file"); + + let run = Command::new("java") + .arg("-cp") + .arg(tmp_dir) + .arg(class_name) + .output() + .expect("failed to run java"); + assert!( + run.status.success(), + "java failed (exit {}): stderr={}", + run.status, + 
String::from_utf8_lossy(&run.stderr) + ); + String::from_utf8_lossy(&run.stdout).trim().to_string() +} + +// --- Unit tests: Constant pool helpers --- + +#[test] +fn test_add_utf8() { + let mut cf = load_basic_class(); + let original_len = cf.const_pool.len(); + let idx = cf.add_utf8("test_string_42"); + assert_eq!(idx, (original_len + 1) as u16); + assert_eq!(cf.const_pool.len(), original_len + 1); + assert_eq!(cf.get_utf8(idx), Some("test_string_42")); +} + +#[test] +fn test_get_or_add_utf8_existing() { + let mut cf = load_basic_class(); + // "Code" is always present in any class file + let existing_idx = cf.find_utf8_index("Code").expect("Code should exist"); + let original_len = cf.const_pool.len(); + let idx = cf.get_or_add_utf8("Code"); + assert_eq!(idx, existing_idx); + assert_eq!(cf.const_pool.len(), original_len, "pool should not grow"); +} + +#[test] +fn test_get_or_add_utf8_new() { + let mut cf = load_basic_class(); + let original_len = cf.const_pool.len(); + let idx = cf.get_or_add_utf8("brand_new_entry"); + assert_eq!(idx, (original_len + 1) as u16); + assert_eq!(cf.const_pool.len(), original_len + 1); + assert_eq!(cf.get_utf8(idx), Some("brand_new_entry")); +} + +#[test] +fn test_add_string() { + let mut cf = load_basic_class(); + let original_len = cf.const_pool.len(); + let string_idx = cf.add_string("hello"); + // Should have added 2 entries: Utf8 + String + assert_eq!(cf.const_pool.len(), original_len + 2); + assert_eq!(string_idx, (original_len + 2) as u16); + // The String should point to the Utf8 + let utf8_idx = (original_len + 1) as u16; + match &cf.const_pool[(string_idx - 1) as usize] { + ConstantInfo::String(s) => assert_eq!(s.string_index, utf8_idx), + other => panic!("expected String constant, got {:?}", other), + } + assert_eq!(cf.get_utf8(utf8_idx), Some("hello")); +} + +#[test] +fn test_get_or_add_string_dedup() { + let mut cf = load_basic_class(); + let idx1 = cf.add_string("dedup_me"); + let len_after_first = cf.const_pool.len(); + 
let idx2 = cf.get_or_add_string("dedup_me"); + assert_eq!(idx1, idx2, "should return same index"); + assert_eq!( + cf.const_pool.len(), + len_after_first, + "pool should not grow on dedup" + ); +} + +#[test] +fn test_add_class() { + let mut cf = load_basic_class(); + let original_len = cf.const_pool.len(); + let class_idx = cf.add_class("com/example/Test"); + assert_eq!(cf.const_pool.len(), original_len + 2); + assert_eq!(class_idx, (original_len + 2) as u16); + match &cf.const_pool[(class_idx - 1) as usize] { + ConstantInfo::Class(c) => { + assert_eq!(cf.get_utf8(c.name_index), Some("com/example/Test")); + } + other => panic!("expected Class constant, got {:?}", other), + } +} + +#[test] +fn test_get_or_add_class_dedup() { + let mut cf = load_basic_class(); + let idx1 = cf.add_class("com/example/Foo"); + let len_after_first = cf.const_pool.len(); + let idx2 = cf.get_or_add_class("com/example/Foo"); + assert_eq!(idx1, idx2, "should return same index"); + assert_eq!( + cf.const_pool.len(), + len_after_first, + "pool should not grow on dedup" + ); +} + +#[test] +fn test_add_name_and_type() { + let mut cf = load_basic_class(); + let original_len = cf.const_pool.len(); + let nat_idx = cf.add_name_and_type("myMethod", "(I)V"); + // get_or_add_utf8 may reuse existing entries; count new entries + assert!(cf.const_pool.len() > original_len); + match &cf.const_pool[(nat_idx - 1) as usize] { + ConstantInfo::NameAndType(nat) => { + assert_eq!(cf.get_utf8(nat.name_index), Some("myMethod")); + assert_eq!(cf.get_utf8(nat.descriptor_index), Some("(I)V")); + } + other => panic!("expected NameAndType constant, got {:?}", other), + } +} + +// --- Unit tests: sync_all --- + +#[test] +fn test_sync_all() { + let mut cf = load_basic_class(); + + // Modify an instruction in a method's code + let method = cf.find_method_mut("<init>").expect("should have <init>"); + let code = method.code_mut().expect("should have code"); + // The constructor should have an Aload0 instruction + assert!( + 
code.code.iter().any(|i| *i == Instruction::Aload0), + "expected Aload0 in <init>" + ); + + // Add a utf8 constant + cf.add_utf8("sync_all_test"); + + // Call sync_all and verify it doesn't error + cf.sync_all().expect("sync_all should succeed"); + + // Verify counts are correct + assert_eq!(cf.const_pool_size, (cf.const_pool.len() + 1) as u16); + assert_eq!(cf.methods_count, cf.methods.len() as u16); + assert_eq!(cf.fields_count, cf.fields.len() as u16); + assert_eq!(cf.attributes_count, cf.attributes.len() as u16); + + // Round-trip: write and re-parse + let mut out = Cursor::new(Vec::new()); + cf.write(&mut out).expect("failed to write"); + let bytes = out.into_inner(); + let reparsed = + ClassFile::read(&mut Cursor::new(&bytes)).expect("failed to re-parse after sync_all"); + assert_eq!(reparsed.const_pool.len(), cf.const_pool.len()); +} + +// --- Unit tests: with_code --- + +#[test] +fn test_with_code() { + let mut cf = load_basic_class(); + let method = cf.find_method_mut("<init>").expect("should have <init>"); + + let result = method.with_code(|code| { + // Find and count Aload0 instructions + code.code + .iter() + .filter(|i| **i == Instruction::Aload0) + .count() + }); + + match result { + Some(Ok(count)) => assert!(count > 0, "should have found Aload0"), + Some(Err(e)) => panic!("sync failed: {:?}", e), + None => panic!("expected Code attribute"), + } +} + +#[test] +fn test_with_code_none() { + let mut cf = load_basic_class(); + // Add a dummy method with no Code attribute (simulating abstract) + use classfile_parser::method_info::MethodAccessFlags; + use classfile_parser::method_info::MethodInfo; + let name_idx = cf.add_utf8("abstractMethod"); + let desc_idx = cf.get_or_add_utf8("()V"); + cf.methods.push(MethodInfo { + access_flags: MethodAccessFlags::ABSTRACT | MethodAccessFlags::PUBLIC, + name_index: name_idx, + descriptor_index: desc_idx, + attributes_count: 0, + attributes: vec![], + }); + cf.sync_counts(); + + let method = 
cf.find_method_mut("abstractMethod").expect("should find"); + let result = method.with_code(|_code| ()); + assert!(result.is_none(), "abstract method should return None"); +} + +// --- Unit tests: instruction helpers --- + +#[test] +fn test_find_instruction() { + let cf = load_basic_class(); + let method = cf.find_method("<init>").expect("should have <init>"); + let code = method.code().expect("should have code"); + + let found = code.find_instruction(|i| *i == Instruction::Aload0); + assert!(found.is_some(), "should find Aload0"); + let (idx, instr) = found.unwrap(); + assert_eq!(*instr, Instruction::Aload0); + assert_eq!(idx, 0, "Aload0 should be first instruction in <init>"); +} + +#[test] +fn test_find_instructions() { + let cf = load_basic_class(); + let method = cf.find_method("<init>").expect("should have <init>"); + let code = method.code().expect("should have code"); + + // Find all return-type instructions + let returns = code.find_instructions(|i| matches!(i, Instruction::Return)); + assert!( + !returns.is_empty(), + "should find at least one Return instruction" + ); +} + +#[test] +fn test_replace_instruction() { + let mut cf = load_basic_class(); + let method = cf.find_method_mut("<init>").expect("should have <init>"); + let code = method.code_mut().expect("should have code"); + + // Find Aload0 and replace with Nop + let (idx, _) = code + .find_instruction(|i| *i == Instruction::Aload0) + .expect("should find Aload0"); + code.replace_instruction(idx, Instruction::Nop); + assert_eq!(code.code[idx], Instruction::Nop); +} + +#[test] +fn test_nop_out() { + let mut cf = load_basic_class(); + let method = cf.find_method_mut("<init>").expect("should have <init>"); + let code = method.code_mut().expect("should have code"); + + // Record original code_length by syncing + code.sync_lengths().expect("sync_lengths"); + let original_code_length = code.code_length; + + // nop_out the first 2 instructions + let original_count = code.code.len(); + code.nop_out(0..2).expect("nop_out should succeed"); + + // Sync and 
verify code_length is preserved + code.sync_lengths().expect("sync_lengths after nop_out"); + assert_eq!( + code.code_length, original_code_length, + "code_length should be preserved after nop_out" + ); + + // Verify the nop'd region is all Nop + // The first N entries (where N = byte size of original 2 instructions) should be Nop + assert!( + code.code.len() >= original_count, + "instruction count should grow or stay same" + ); + for i in &code.code[..code.code.len() - (original_count - 2)] { + assert_eq!(*i, Instruction::Nop, "nop'd region should be all Nop"); + } +} + +// --- E2E tests --- + +/// Rewrite of test_e2e_add_constant_and_redirect_ldc using helpers: +/// add_string + with_code + find_instruction + replace_instruction + sync_all +#[test] +fn test_e2e_helpers_redirect_ldc() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "helpers_ldc", + "java-assets/src/HelloWorld.java", + "HelloWorld", + ); + + // Add "Injected!" string to pool using helper + let string_idx = class_file.add_string("Injected!"); + assert!(string_idx <= 255, "index must fit in u8 for ldc"); + + // Find main and redirect the Ldc + let method = class_file + .find_method_mut("main") + .expect("should have main"); + let result = method.with_code(|code| { + let (idx, _) = code + .find_instruction(|i| matches!(i, Instruction::Ldc(_))) + .expect("should find Ldc"); + code.replace_instruction(idx, Instruction::Ldc(string_idx as u8)); + }); + result + .expect("should have code") + .expect("sync should succeed"); + + class_file.sync_all().expect("sync_all"); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "HelloWorld"); + assert_eq!( + output, "Injected!", + "expected 'Injected!' 
but got: {}", + output + ); +} + +/// Rewrite of test_e2e_remove_method using nop_out + with_code +#[test] +fn test_e2e_helpers_nop_out() { + if !java_available() { + eprintln!("skipping: javac/java not found"); + return; + } + + let (tmp_dir, class_path, mut class_file) = compile_and_load( + "helpers_nop", + "java-assets/src/SimpleMath.java", + "SimpleMath", + ); + + // Nop out the first 4 instructions of main (the "Integer math:" println + intMath call) + let method = class_file + .find_method_mut("main") + .expect("should have main"); + let result = method.with_code(|code| { + // Verify the instructions we expect + assert!(matches!(&code.code[0], Instruction::Getstatic(_))); + assert!(matches!(&code.code[1], Instruction::Ldc(_))); + assert!(matches!(&code.code[2], Instruction::Invokevirtual(_))); + assert!(matches!(&code.code[3], Instruction::Invokestatic(_))); + code.nop_out(0..4).expect("nop_out should succeed"); + }); + result + .expect("should have code") + .expect("sync should succeed"); + + // Remove intMath method + let int_math_idx = class_file + .methods + .iter() + .position(|m| class_file.get_utf8(m.name_index) == Some("intMath")) + .expect("intMath not found"); + class_file.methods.remove(int_math_idx); + + class_file.sync_all().expect("sync_all"); + + let output = write_and_run(&tmp_dir, &class_path, &class_file, "SimpleMath"); + assert!( + !output.contains("Integer math:"), + "should not contain 'Integer math:', got: {}", + output + ); + assert!( + output.contains("Float math:"), + "expected 'Float math:' in output: {}", + output + ); +} diff --git a/tests/spring_utils.rs b/tests/spring_utils.rs new file mode 100644 index 0000000..0a96196 --- /dev/null +++ b/tests/spring_utils.rs @@ -0,0 +1,757 @@ +#![cfg(feature = "spring-utils")] + +extern crate classfile_parser; + +use std::io::{Cursor, Write}; + +use classfile_parser::jar_utils::{JarFile, JarManifest}; +use classfile_parser::spring_utils::{ + ClasspathIndex, LayersIndex, SpringBootFormat, 
SpringBootJar, +}; + +use zip::CompressionMethod; +use zip::write::SimpleFileOptions; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Read a .class file from disk as raw bytes. +fn read_class_bytes(name: &str) -> Vec<u8> { + let path = format!("java-assets/compiled-classes/{name}"); + std::fs::read(&path).unwrap_or_else(|e| panic!("failed to read {path}: {e}")) +} + +/// Build a manifest for a Spring Boot JAR/WAR. +fn spring_manifest(format: SpringBootFormat, start_class: &str) -> Vec<u8> { + let launcher = match format { + SpringBootFormat::Jar => "org.springframework.boot.loader.JarLauncher", + SpringBootFormat::War => "org.springframework.boot.loader.WarLauncher", + }; + let mut m = JarManifest::default_manifest(); + m.set_main_attr("Main-Class", launcher); + m.set_main_attr("Start-Class", start_class); + m.set_main_attr("Spring-Boot-Version", "3.1.0"); + m.set_main_attr("Spring-Boot-Classes", format.classes_dir()); + m.set_main_attr("Spring-Boot-Lib", format.lib_dir()); + m.to_bytes() +} + +/// Build a small JAR in-memory (for nesting inside a fat JAR). +fn build_inner_jar(entries: &[(&str, &[u8])]) -> Vec<u8> { + let mut buf = Cursor::new(Vec::new()); + { + let mut writer = zip::ZipWriter::new(&mut buf); + let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated); + for (name, data) in entries { + writer.start_file(*name, options).unwrap(); + writer.write_all(data).unwrap(); + } + writer.finish().unwrap(); + } + buf.into_inner() +} + +/// Build a complete Spring Boot fat JAR in-memory. 
+/// +/// `app_classes`: entries under `{prefix}/classes/` +/// `nested_jars`: entries under `{prefix}/lib/` (name, jar bytes) +/// `loader_classes`: entries under `org/springframework/boot/loader/` +/// `classpath_idx`: optional classpath.idx content +/// `layers_idx`: optional layers.idx content +fn build_spring_jar( + format: SpringBootFormat, + start_class: &str, + app_classes: &[(&str, &[u8])], + app_resources: &[(&str, &[u8])], + nested_jars: &[(&str, &[u8])], + loader_classes: &[&str], + classpath_idx: Option<&[u8]>, + layers_idx: Option<&[u8]>, +) -> Vec<u8> { + let manifest = spring_manifest(format, start_class); + let mut entries: Vec<(String, Vec<u8>)> = Vec::new(); + + // Manifest + entries.push(("META-INF/MANIFEST.MF".to_string(), manifest)); + + // App classes + let classes_dir = format.classes_dir(); + for (name, data) in app_classes { + entries.push((format!("{classes_dir}{name}"), data.to_vec())); + } + + // App resources + for (name, data) in app_resources { + entries.push((format!("{classes_dir}{name}"), data.to_vec())); + } + + // Nested JARs + let lib_dir = format.lib_dir(); + for (name, data) in nested_jars { + entries.push((format!("{lib_dir}{name}"), data.to_vec())); + } + + // Loader classes + for name in loader_classes { + entries.push(( + format!("org/springframework/boot/loader/{name}"), + b"fake-loader-class".to_vec(), + )); + } + + // Index files + if let Some(data) = classpath_idx { + entries.push((format!("{}/classpath.idx", format.prefix()), data.to_vec())); + } + if let Some(data) = layers_idx { + entries.push((format!("{}/layers.idx", format.prefix()), data.to_vec())); + } + + // Build the ZIP + let mut buf = Cursor::new(Vec::new()); + { + let mut writer = zip::ZipWriter::new(&mut buf); + let options = SimpleFileOptions::default().compression_method(CompressionMethod::Deflated); + for (name, data) in &entries { + writer.start_file(name, options).unwrap(); + writer.write_all(data).unwrap(); + } + writer.finish().unwrap(); + } + 
buf.into_inner() +} + +// =========================================================================== +// Detection tests +// =========================================================================== + +#[test] +fn test_detect_jar_format() { + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap(); + assert!(spring.is_some()); + assert_eq!(spring.unwrap().format(), SpringBootFormat::Jar); +} + +#[test] +fn test_detect_war_format() { + let jar_bytes = build_spring_jar( + SpringBootFormat::War, + "com.example.App", + &[], + &[], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap(); + assert!(spring.is_some()); + assert_eq!(spring.unwrap().format(), SpringBootFormat::War); +} + +#[test] +fn test_detect_spring_boot_3_2_launcher() { + // Use the Spring Boot 3.2+ launcher class name + let mut m = JarManifest::default_manifest(); + m.set_main_attr( + "Main-Class", + "org.springframework.boot.loader.launch.JarLauncher", + ); + m.set_main_attr("Start-Class", "com.example.App"); + let manifest_bytes = m.to_bytes(); + + let inner_jar = build_inner_jar(&[("META-INF/MANIFEST.MF", &manifest_bytes)]); + let jar = JarFile::from_bytes(&inner_jar).unwrap(); + let format = classfile_parser::spring_utils::detect_format(&jar); + assert_eq!(format, Some(SpringBootFormat::Jar)); +} + +#[test] +fn test_detect_not_spring_boot() { + // Plain JAR with no Spring Boot launcher + let mut m = JarManifest::default_manifest(); + m.set_main_attr("Main-Class", "com.example.Main"); + let manifest_bytes = m.to_bytes(); + + let inner_jar = build_inner_jar(&[("META-INF/MANIFEST.MF", &manifest_bytes)]); + let result = SpringBootJar::from_bytes(&inner_jar).unwrap(); + assert!(result.is_none()); +} + +#[test] +fn test_detect_no_start_class() { + // Has JarLauncher but no Start-Class → not detected + let mut m = 
JarManifest::default_manifest(); + m.set_main_attr("Main-Class", "org.springframework.boot.loader.JarLauncher"); + let manifest_bytes = m.to_bytes(); + + let inner_jar = build_inner_jar(&[("META-INF/MANIFEST.MF", &manifest_bytes)]); + let result = SpringBootJar::from_bytes(&inner_jar).unwrap(); + assert!(result.is_none()); +} + +// =========================================================================== +// Manifest shortcut tests +// =========================================================================== + +#[test] +fn test_manifest_shortcuts() { + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.MyApp", + &[], + &[], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + assert_eq!( + spring.start_class().unwrap(), + Some("com.example.MyApp".to_string()) + ); + assert_eq!( + spring.spring_boot_version().unwrap(), + Some("3.1.0".to_string()) + ); + assert_eq!( + spring.spring_boot_classes_path().unwrap(), + Some("BOOT-INF/classes/".to_string()) + ); + assert_eq!( + spring.spring_boot_lib_path().unwrap(), + Some("BOOT-INF/lib/".to_string()) + ); +} + +#[test] +fn test_manifest_shortcuts_missing() { + // Build a minimal manifest without version/paths + let mut m = JarManifest::default_manifest(); + m.set_main_attr("Main-Class", "org.springframework.boot.loader.JarLauncher"); + m.set_main_attr("Start-Class", "com.example.App"); + let manifest_bytes = m.to_bytes(); + + let jar_bytes = build_inner_jar(&[("META-INF/MANIFEST.MF", &manifest_bytes)]); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + assert_eq!( + spring.start_class().unwrap(), + Some("com.example.App".to_string()) + ); + assert_eq!(spring.spring_boot_version().unwrap(), None); + assert_eq!(spring.spring_boot_classes_path().unwrap(), None); + assert_eq!(spring.spring_boot_lib_path().unwrap(), None); +} + +// =========================================================================== +// Entry 
iteration tests +// =========================================================================== + +#[test] +fn test_app_class_names() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[ + ("com/example/Foo.class", &class_bytes), + ("com/example/Bar.class", &class_bytes), + ], + &[("application.properties", b"key=value")], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let class_names: Vec<&str> = spring.app_class_names().collect(); + assert_eq!(class_names.len(), 2); + assert!( + class_names + .iter() + .all(|n| n.starts_with("BOOT-INF/classes/") && n.ends_with(".class")) + ); +} + +#[test] +fn test_app_resource_names() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[("com/example/Foo.class", &class_bytes)], + &[ + ("application.properties", b"key=value"), + ("static/index.html", b""), + ], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let resources: Vec<&str> = spring.app_resource_names().collect(); + assert_eq!(resources.len(), 2); + assert!(resources.iter().all(|n| !n.ends_with(".class"))); +} + +#[test] +fn test_loader_class_names() { + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[], + &["JarLauncher.class", "LaunchedURLClassLoader.class"], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let loaders: Vec<&str> = spring.loader_class_names().collect(); + assert_eq!(loaders.len(), 2); + assert!( + loaders + .iter() + .all(|n| n.starts_with("org/springframework/boot/loader/")) + ); +} + +#[test] +fn test_nested_jar_names() { + let inner = build_inner_jar(&[("com/lib/A.class", b"fake-class")]); + let inner2 = build_inner_jar(&[("com/lib/B.class", 
b"fake-class")]); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[("dep-a.jar", &inner), ("dep-b.jar", &inner2)], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let nested: Vec<&str> = spring.nested_jar_names().collect(); + assert_eq!(nested.len(), 2); + assert!( + nested + .iter() + .all(|n| n.starts_with("BOOT-INF/lib/") && n.ends_with(".jar")) + ); +} + +// =========================================================================== +// Nested JAR tests +// =========================================================================== + +#[test] +fn test_open_nested_jar() { + let inner = build_inner_jar(&[ + ("com/lib/Helper.class", b"fake-class"), + ("META-INF/MANIFEST.MF", b"Manifest-Version: 1.0\r\n"), + ]); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[("helper-lib.jar", &inner)], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let nested = spring + .open_nested_jar("BOOT-INF/lib/helper-lib.jar") + .unwrap(); + let names: Vec<&str> = nested.entry_names().collect(); + assert_eq!(names.len(), 2); + assert!(nested.contains_entry("com/lib/Helper.class")); +} + +#[test] +fn test_open_nested_jar_not_found() { + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let result = spring.open_nested_jar("BOOT-INF/lib/nonexistent.jar"); + assert!(result.is_err()); +} + +#[test] +fn test_parse_nested_class() { + let class_bytes = read_class_bytes("BasicClass.class"); + let inner = build_inner_jar(&[("com/dep/BasicClass.class", &class_bytes)]); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[("dep.jar", &inner)], + &[], + None, + None, + ); + let spring = 
SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let class_file = spring + .parse_nested_class("BOOT-INF/lib/dep.jar", "com/dep/BasicClass.class") + .unwrap(); + assert!(class_file.major_version >= 45); +} + +#[test] +fn test_open_all_nested_jars() { + let inner1 = build_inner_jar(&[("A.class", b"fake")]); + let inner2 = build_inner_jar(&[("B.class", b"fake")]); + let inner3 = build_inner_jar(&[("C.class", b"fake")]); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[("a.jar", &inner1), ("b.jar", &inner2), ("c.jar", &inner3)], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let all = spring.open_all_nested_jars(); + assert_eq!(all.len(), 3); + assert!(all.iter().all(|(_, r)| r.is_ok())); +} + +// =========================================================================== +// App class tests +// =========================================================================== + +#[test] +fn test_parse_app_class() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[("com/example/BasicClass.class", &class_bytes)], + &[], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let cf = spring + .parse_app_class("BOOT-INF/classes/com/example/BasicClass.class") + .unwrap(); + assert!(cf.major_version >= 45); +} + +#[test] +fn test_parse_all_app_classes() { + let class_bytes = read_class_bytes("BasicClass.class"); + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[ + ("com/example/A.class", &class_bytes), + ("com/example/B.class", &class_bytes), + ], + &[("application.properties", b"key=val")], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let results = spring.parse_all_app_classes(); + assert_eq!(results.len(), 2); + 
assert!(results.iter().all(|(_, r)| r.is_ok())); +} + +// =========================================================================== +// ClasspathIndex tests +// =========================================================================== + +#[test] +fn test_classpath_index_parse() { + let data = b"- \"BOOT-INF/lib/spring-core.jar\"\n- \"BOOT-INF/lib/spring-web.jar\"\n"; + let idx = ClasspathIndex::parse(data).unwrap(); + + assert_eq!(idx.len(), 2); + assert_eq!(idx.entries()[0], "BOOT-INF/lib/spring-core.jar"); + assert_eq!(idx.entries()[1], "BOOT-INF/lib/spring-web.jar"); + assert!(idx.contains("BOOT-INF/lib/spring-core.jar")); + assert!(!idx.contains("BOOT-INF/lib/nonexistent.jar")); + assert!(!idx.is_empty()); +} + +#[test] +fn test_classpath_index_round_trip() { + let data = b"- \"BOOT-INF/lib/a.jar\"\n- \"BOOT-INF/lib/b.jar\"\n- \"BOOT-INF/lib/c.jar\"\n"; + let idx = ClasspathIndex::parse(data).unwrap(); + let serialized = idx.to_bytes(); + assert_eq!(serialized, data.to_vec()); +} + +#[test] +fn test_classpath_index_from_jar() { + let cp_data = b"- \"BOOT-INF/lib/dep-a.jar\"\n- \"BOOT-INF/lib/dep-b.jar\"\n"; + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[], + &[], + Some(cp_data), + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let idx = spring.classpath_index().unwrap().unwrap(); + assert_eq!(idx.len(), 2); + assert!(idx.contains("BOOT-INF/lib/dep-a.jar")); +} + +#[test] +fn test_classpath_index_missing() { + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.App", + &[], + &[], + &[], + &[], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + let idx = spring.classpath_index().unwrap(); + assert!(idx.is_none()); +} + +// =========================================================================== +// LayersIndex tests +// =========================================================================== + 
+#[test] +fn test_layers_index_parse() { + let data = b"- \"dependencies\":\n - \"BOOT-INF/lib/\"\n- \"application\":\n - \"BOOT-INF/classes/\"\n - \"BOOT-INF/lib/app-dep.jar\"\n"; + let idx = LayersIndex::parse(data).unwrap(); + + assert_eq!(idx.len(), 2); + assert_eq!(idx.layers()[0].name, "dependencies"); + assert_eq!(idx.layers()[0].paths, vec!["BOOT-INF/lib/"]); + assert_eq!(idx.layers()[1].name, "application"); + assert_eq!(idx.layers()[1].paths.len(), 2); +} + +#[test] +fn test_layers_index_round_trip() { + let data = "- \"dependencies\":\n - \"BOOT-INF/lib/\"\n- \"application\":\n - \"BOOT-INF/classes/\"\n"; + let idx = LayersIndex::parse(data.as_bytes()).unwrap(); + let serialized = idx.to_bytes(); + assert_eq!(String::from_utf8(serialized).unwrap(), data); +} + +#[test] +fn test_layers_index_empty_layer() { + let data = b"- \"empty-layer\":\n- \"nonempty\":\n - \"BOOT-INF/classes/\"\n"; + let idx = LayersIndex::parse(data).unwrap(); + + assert_eq!(idx.len(), 2); + assert!(idx.find_layer("empty-layer").unwrap().paths.is_empty()); + assert_eq!(idx.find_layer("nonempty").unwrap().paths.len(), 1); +} + +#[test] +fn test_layers_index_find_layer() { + let data = b"- \"deps\":\n - \"BOOT-INF/lib/\"\n- \"app\":\n - \"BOOT-INF/classes/\"\n"; + let idx = LayersIndex::parse(data).unwrap(); + + assert!(idx.find_layer("deps").is_some()); + assert!(idx.find_layer("app").is_some()); + assert!(idx.find_layer("nonexistent").is_none()); + + let names: Vec<&str> = idx.layer_names().collect(); + assert_eq!(names, vec!["deps", "app"]); +} + +#[test] +fn test_layers_index_layer_for_path() { + let data = b"- \"dependencies\":\n - \"BOOT-INF/lib/\"\n- \"application\":\n - \"BOOT-INF/classes/\"\n"; + let idx = LayersIndex::parse(data).unwrap(); + + assert_eq!( + idx.layer_for_path("BOOT-INF/lib/spring-core.jar"), + Some("dependencies") + ); + assert_eq!( + idx.layer_for_path("BOOT-INF/classes/com/example/App.class"), + Some("application") + ); + 
assert_eq!(idx.layer_for_path("META-INF/MANIFEST.MF"), None); +} + +// =========================================================================== +// WAR variant test +// =========================================================================== + +#[test] +fn test_war_variant() { + let class_bytes = read_class_bytes("BasicClass.class"); + let inner = build_inner_jar(&[("com/lib/A.class", b"fake")]); + let jar_bytes = build_spring_jar( + SpringBootFormat::War, + "com.example.WarApp", + &[("com/example/Svc.class", &class_bytes)], + &[("application.yml", b"server:\n port: 8080\n")], + &[("dep.jar", &inner)], + &["JarLauncher.class"], + None, + None, + ); + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + assert_eq!(spring.format(), SpringBootFormat::War); + assert_eq!( + spring.start_class().unwrap(), + Some("com.example.WarApp".to_string()) + ); + + let classes: Vec<&str> = spring.app_class_names().collect(); + assert_eq!(classes.len(), 1); + assert!(classes[0].starts_with("WEB-INF/classes/")); + + let resources: Vec<&str> = spring.app_resource_names().collect(); + assert_eq!(resources.len(), 1); + + let nested: Vec<&str> = spring.nested_jar_names().collect(); + assert_eq!(nested.len(), 1); + assert!(nested[0].starts_with("WEB-INF/lib/")); +} + +// =========================================================================== +// Integration test +// =========================================================================== + +#[test] +fn test_full_fat_jar_analysis() { + let class_bytes = read_class_bytes("BasicClass.class"); + let factorial_bytes = read_class_bytes("Factorial.class"); + + // Build nested JARs with real classes + let dep_jar = build_inner_jar(&[("com/dep/BasicClass.class", &class_bytes)]); + let util_jar = build_inner_jar(&[("com/util/Factorial.class", &factorial_bytes)]); + + let cp_idx = b"- \"BOOT-INF/lib/dep.jar\"\n- \"BOOT-INF/lib/util.jar\"\n"; + let layers_idx = "- \"dependencies\":\n - \"BOOT-INF/lib/\"\n- 
\"application\":\n - \"BOOT-INF/classes/\"\n"; + + let jar_bytes = build_spring_jar( + SpringBootFormat::Jar, + "com.example.Main", + &[ + ("com/example/Main.class", &class_bytes), + ("com/example/Service.class", &factorial_bytes), + ], + &[("application.properties", b"spring.application.name=test")], + &[("dep.jar", &dep_jar), ("util.jar", &util_jar)], + &["JarLauncher.class", "LaunchedURLClassLoader.class"], + Some(cp_idx), + Some(layers_idx.as_bytes()), + ); + + let spring = SpringBootJar::from_bytes(&jar_bytes).unwrap().unwrap(); + + // Format + assert_eq!(spring.format(), SpringBootFormat::Jar); + + // Manifest + assert_eq!( + spring.start_class().unwrap(), + Some("com.example.Main".to_string()) + ); + assert_eq!( + spring.spring_boot_version().unwrap(), + Some("3.1.0".to_string()) + ); + + // App classes + let app_classes: Vec<&str> = spring.app_class_names().collect(); + assert_eq!(app_classes.len(), 2); + let parsed = spring.parse_all_app_classes(); + assert_eq!(parsed.len(), 2); + assert!(parsed.iter().all(|(_, r)| r.is_ok())); + + // App resources + let resources: Vec<&str> = spring.app_resource_names().collect(); + assert_eq!(resources.len(), 1); + + // Loader classes + let loaders: Vec<&str> = spring.loader_class_names().collect(); + assert_eq!(loaders.len(), 2); + + // Nested JARs + let nested: Vec<&str> = spring.nested_jar_names().collect(); + assert_eq!(nested.len(), 2); + + let all_nested = spring.open_all_nested_jars(); + assert_eq!(all_nested.len(), 2); + assert!(all_nested.iter().all(|(_, r)| r.is_ok())); + + // Parse class from nested JAR + let cf = spring + .parse_nested_class("BOOT-INF/lib/dep.jar", "com/dep/BasicClass.class") + .unwrap(); + assert!(cf.major_version >= 45); + + // Classpath index + let cp = spring.classpath_index().unwrap().unwrap(); + assert_eq!(cp.len(), 2); + assert!(cp.contains("BOOT-INF/lib/dep.jar")); + assert!(cp.contains("BOOT-INF/lib/util.jar")); + + // Layers index + let layers = 
spring.layers_index().unwrap().unwrap(); + assert_eq!(layers.len(), 2); + assert_eq!( + layers.layer_for_path("BOOT-INF/lib/dep.jar"), + Some("dependencies") + ); + assert_eq!( + layers.layer_for_path("BOOT-INF/classes/com/example/Main.class"), + Some("application") + ); +}
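The `LayersIndex` tests above exercise a small YAML-like format: a top-level `- "name":` line opens a layer, and indented `- "path"` lines list the prefixes belonging to it. As a rough illustration of the behaviour those tests expect, here is a minimal standalone sketch. It is not the crate's `spring_utils` implementation; `Layer`, `parse_layers`, `layer_for_path`, and `unquote` are hypothetical names chosen for this example, and the first-match prefix rule is an assumption inferred from `test_layers_index_layer_for_path`.

```rust
/// One layer from a `layers.idx`-style file: a name plus its path prefixes.
/// (Hypothetical type for illustration, not the crate's own.)
#[derive(Debug, Clone, PartialEq)]
pub struct Layer {
    pub name: String,
    pub paths: Vec<String>,
}

/// Strip the leading `- ` list marker and the surrounding double quotes.
fn unquote(item: &str) -> Option<&str> {
    item.strip_prefix("- ")?.strip_prefix('"')?.strip_suffix('"')
}

/// Parse `layers.idx`-style text into an ordered list of layers.
/// Unindented lines (`- "name":`) start a layer; indented lines
/// (`  - "path"`) add a path to the most recently started layer.
pub fn parse_layers(text: &str) -> Result<Vec<Layer>, String> {
    let mut layers: Vec<Layer> = Vec::new();
    for (n, raw) in text.lines().enumerate() {
        let line = raw.trim_end();
        if line.trim().is_empty() {
            continue; // tolerate blank lines
        }
        let indented = line.starts_with(' ');
        let item = line.trim_start();
        if indented {
            let path = unquote(item)
                .ok_or_else(|| format!("line {}: malformed path entry", n + 1))?;
            let layer = layers
                .last_mut()
                .ok_or_else(|| format!("line {}: path before any layer", n + 1))?;
            layer.paths.push(path.to_string());
        } else {
            let head = item
                .strip_suffix(':')
                .ok_or_else(|| format!("line {}: layer header must end with ':'", n + 1))?;
            let name = unquote(head)
                .ok_or_else(|| format!("line {}: malformed layer header", n + 1))?;
            layers.push(Layer {
                name: name.to_string(),
                paths: Vec::new(),
            });
        }
    }
    Ok(layers)
}

/// Return the name of the first layer that has a listed path which is a
/// prefix of `path` (assumed matching rule), or `None` if no layer matches.
pub fn layer_for_path<'a>(layers: &'a [Layer], path: &str) -> Option<&'a str> {
    layers
        .iter()
        .find(|l| l.paths.iter().any(|p| path.starts_with(p.as_str())))
        .map(|l| l.name.as_str())
}
```

With the two-layer input used in the tests, `layer_for_path` would resolve `BOOT-INF/lib/spring-core.jar` to `dependencies` via the `BOOT-INF/lib/` prefix, while `META-INF/MANIFEST.MF` matches no prefix and yields `None`. A header with no indented entries simply produces a layer with an empty `paths` list, which mirrors the empty-layer test case.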