diff --git a/include/arbiter/arbiter.h b/include/arbiter/arbiter.h index 6adf97a..fdd82ca 100644 --- a/include/arbiter/arbiter.h +++ b/include/arbiter/arbiter.h @@ -144,6 +144,23 @@ int ARBITER_fault_is_raised(const struct ARBITER_result *result, int ARBITER_get_requested_actions(const struct ARBITER_result *result, const uint16_t **actions, size_t *count); +/** + * @brief Set multiple fact values in a single call. + * + * More efficient than calling ARBITER_set_i32() N times because + * context validation is performed once. + * + * @param ctx Initialized context. + * @param fact_ids Array of fact indices. + * @param values Array of int32_t values. + * @param count Number of elements. + * @return ARBITER_OK on success, or the first error encountered. + */ +int ARBITER_set_facts(struct ARBITER_ctx *ctx, + const uint16_t *fact_ids, + const int32_t *values, + uint16_t count); + /** * @brief Get the operation count from the last evaluation. */ diff --git a/include/arbiter/arbiter_model.h b/include/arbiter/arbiter_model.h index 36d2967..21fdab2 100644 --- a/include/arbiter/arbiter_model.h +++ b/include/arbiter/arbiter_model.h @@ -156,6 +156,7 @@ struct ARBITER_rule_def { arbiter_index_t expr_count; /**< Number of expressions to evaluate. */ arbiter_index_t safety_goal_id; arbiter_index_t set_mode; + arbiter_index_t required_mode; /**< Mode required to evaluate (INDEX_MAX = any). */ bool safety_critical; #if !defined(CONFIG_ARBITER_STRINGS) || CONFIG_ARBITER_STRINGS const char *name; diff --git a/include/arbiter/arbiter_result.h b/include/arbiter/arbiter_result.h index b94e691..7630955 100644 --- a/include/arbiter/arbiter_result.h +++ b/include/arbiter/arbiter_result.h @@ -29,6 +29,9 @@ struct ARBITER_snapshot { uint16_t count; uint32_t timestamp_ms; bool frozen; +#if defined(CONFIG_ARBITER_DIRTY_SKIP) && CONFIG_ARBITER_DIRTY_SKIP + uint64_t dirty_mask; /**< Bitmask of facts changed since last eval. */ +#endif }; /** Evaluation result. */ diff --git a/lib/arbiter_eval.c b/lib/arbiter_eval.c index 5b5b3cd..efee14a 100644 --- a/lib/arbiter_eval.c +++ b/lib/arbiter_eval.c @@ -62,6 +62,20 @@ ARBITER_ALWAYS_INLINE int32_t resolve_operand( return 0; } +/* ── Condition result cache ────────────────────────────────────── */ + +#if defined(CONFIG_ARBITER_COND_CACHE) && CONFIG_ARBITER_COND_CACHE +#define ARBITER_COND_CACHE_SIZE 8 + +struct arbiter_cond_cache_entry { + uint16_t fact_id; + uint8_t op; + int32_t value; + bool result; + bool valid; +}; +#endif /* CONFIG_ARBITER_COND_CACHE */ + /* ── Condition evaluator ──────────────────────────────────────── */ /** @@ -403,9 +417,61 @@ int ARBITER_eval(const struct ARBITER_model *model, */ uint32_t ops = 0; +#if defined(CONFIG_ARBITER_COND_CACHE) && CONFIG_ARBITER_COND_CACHE + /* Per-eval condition result cache — reset each cycle. */ + struct arbiter_cond_cache_entry cond_cache[ARBITER_COND_CACHE_SIZE]; + + memset(cond_cache, 0, sizeof(cond_cache)); +#endif + +#if defined(CONFIG_ARBITER_DIRTY_SKIP) && CONFIG_ARBITER_DIRTY_SKIP + const uint64_t dirty_mask = snapshot->dirty_mask; +#endif + for (arbiter_index_t r = 0; r < rule_count; r++) { const struct ARBITER_rule_def *__restrict rule = &rules[r]; + /* ── Mode-aware pruning ──────────────────────── */ + if (rule->required_mode != ARBITER_INDEX_MAX && + rule->required_mode != result->current_mode) { + ops += 1u; + continue; + } + +#if defined(CONFIG_ARBITER_DIRTY_SKIP) && CONFIG_ARBITER_DIRTY_SKIP + /* ── Dirty-flag rule skip ────────────────────── */ + /* + * If none of the rule's input facts have changed, + * skip re-evaluation (keep previous result). + * For rules with set_mode or no conditions, always + * evaluate (dep_mask == 0 means unconditional). + */ + if (rule->condition_count > 0) { + uint64_t dep_mask = 0; + const arbiter_index_t cs = rule->condition_start; + const arbiter_index_t cc = rule->condition_count; + + for (arbiter_index_t ci = 0; ci < cc; ci++) { + const arbiter_index_t idx = cs + ci; + + if (likely(idx < cond_table_count)) { + const arbiter_index_t fid = + conds[idx].fact_id; + + if (fid < 64) { + dep_mask |= + ((uint64_t)1 << fid); + } + } + } + if (dep_mask != 0 && + (dirty_mask & dep_mask) == 0) { + ops += 1u; + continue; + } + } +#endif /* CONFIG_ARBITER_DIRTY_SKIP */ + /* ── Conditions ──────────────────────────────── */ const bool fired = eval_condition_group( conds, values, vcount, snap_ts, diff --git a/lib/arbiter_fact_store.c b/lib/arbiter_fact_store.c index 2c9f91c..19e0a17 100644 --- a/lib/arbiter_fact_store.c +++ b/lib/arbiter_fact_store.c @@ -96,6 +96,38 @@ int ARBITER_set_timestamp(struct ARBITER_ctx *ctx, uint16_t fact_id, return ARBITER_OK; } +int ARBITER_set_facts(struct ARBITER_ctx *ctx, + const uint16_t *fact_ids, + const int32_t *values, + uint16_t count) +{ + if (unlikely(ctx == NULL || !ctx->initialized || + fact_ids == NULL || values == NULL)) { + return ARBITER_EINVAL; + } + + const arbiter_index_t fc = ctx->model->fact_count; + + for (uint16_t i = 0; i < count; i++) { + const uint16_t fid = fact_ids[i]; + + if (unlikely(fid >= fc)) { + return ARBITER_ERANGE; + } + + struct ARBITER_fact_value *__restrict fv = + &ctx->fact_values[fid]; + const int32_t old = fv->value; + + fv->prev_value = old; + fv->value = values[i]; + fv->valid = true; + fv->changed = (values[i] != old); + } + + return ARBITER_OK; +} + int ARBITER_snapshot_begin(struct ARBITER_ctx *ctx, struct ARBITER_snapshot *snapshot) { @@ -109,5 +141,18 @@ int ARBITER_snapshot_begin(struct ARBITER_ctx *ctx, snapshot->timestamp_ms = k_uptime_get_32(); snapshot->frozen = true; +#if defined(CONFIG_ARBITER_DIRTY_SKIP) && CONFIG_ARBITER_DIRTY_SKIP + /* Build dirty bitmask: OR BIT(fid) for each changed fact. */ + uint64_t mask = 0; + const arbiter_index_t fc = ctx->model->fact_count; + + for (arbiter_index_t i = 0; i < fc && i < 64; i++) { + if (ctx->fact_values[i].changed) { + mask |= ((uint64_t)1 << i); + } + } + snapshot->dirty_mask = mask; +#endif + return ARBITER_OK; } diff --git a/python/arbiter/emit_blob.py b/python/arbiter/emit_blob.py index aa5a5bd..9563237 100644 --- a/python/arbiter/emit_blob.py +++ b/python/arbiter/emit_blob.py @@ -47,7 +47,7 @@ # Wire sizes for packed structs (all little-endian, uint16 indices) _FACT_ELEM_SIZE = 16 # id(2) + type(1) + pad(1) + range_min(4) + range_max(4) + default(4) + stale(2) + safety(1) + pad(1) => rearranged below -_RULE_ELEM_SIZE = 20 +_RULE_ELEM_SIZE = 22 _COND_ELEM_SIZE = 12 _EXPR_ELEM_SIZE = 20 _ACTION_ELEM_SIZE = 12 @@ -127,7 +127,7 @@ def _pack_facts(model: CanonicalModel) -> bytes: def _pack_rules(model: CanonicalModel) -> bytes: """Pack rule definitions. - Wire layout per rule (20 bytes): + Wire layout per rule (22 bytes): id: uint16 LE rule_class: uint8 safety_critical: uint8 @@ -139,6 +139,7 @@ def _pack_rules(model: CanonicalModel) -> bytes: expr_count: uint16 LE safety_goal_id: uint16 LE set_mode: uint16 LE + required_mode: uint16 LE """ buf = bytearray() cond_offset = 0 @@ -175,14 +176,18 @@ def _pack_rules(model: CanonicalModel) -> bytes: expr_start = r.get("_expr_start", 0) expr_count = r.get("_expr_count", 0) + # required_mode: 0xFFFF means any mode + required_mode = 0xFFFF + buf += struct.pack( - " bytes: if model.facts: sections.append((SECTION_FACTS, facts_data, len(model.facts), fact_elem)) - rule_elem = 20 + rule_elem = 22 if model.rules: sections.append((SECTION_RULES, rules_data, len(model.rules), rule_elem)) diff --git a/python/arbiter/emit_c.py b/python/arbiter/emit_c.py index c5f7638..f4d2f1d 100644 --- a/python/arbiter/emit_c.py +++ b/python/arbiter/emit_c.py @@ -48,6 +48,67 @@ } +def _compute_rule_dep_masks(model: CanonicalModel) -> list[int]: + """Compute a uint64 bitmask of fact IDs each rule depends on.""" + masks: list[int] = [] + cond_offset = 0 + for r in model.rules: + when = r.get("when", {}) + cond_count = 0 + if isinstance(when, dict): + for gk in ("all", "any", "not"): + g = when.get(gk) + if isinstance(g, list): + cond_count += len(g) + mask = 0 + for ci in range(cond_offset, cond_offset + cond_count): + if ci < len(model.conditions): + fid = model.conditions[ci].get("fact_id", 0) + if fid < 64: + mask |= 1 << fid + masks.append(mask) + cond_offset += cond_count + return masks + + +def _compute_required_mode( + rule: dict, + model: CanonicalModel, +) -> str: + """Return the C literal for required_mode. + + For mode_guard rules whose 'when' block contains an equality check + on a fact named 'mode' (or similar mode fact), extract the mode value. + Otherwise return UINT16_MAX (any mode). + """ + if rule.get("class") != "mode_guard": + return "UINT16_MAX" + + when = rule.get("when", {}) + if not isinstance(when, dict): + return "UINT16_MAX" + + # Scan all condition groups for mode equality checks + for gk in ("all", "any"): + group = when.get(gk) + if not isinstance(group, list): + continue + for cond in group: + if not isinstance(cond, dict): + continue + fact_name = cond.get("fact", "") + op = cond.get("op", "") + # Look for mode facts by checking if the fact name contains 'mode' + if "mode" in fact_name.lower() and op == "==": + val = cond.get("value") + if isinstance(val, int): + return str(val) + # Value might be a mode name string + if isinstance(val, str) and val in model.mode_id_map: + return str(model.mode_id_map[val]) + return "UINT16_MAX" + + def _c_str(s: str | None) -> str: return f'"{s}"' if s else "NULL" @@ -99,6 +160,18 @@ def emit_c_header(model: CanonicalModel, emit_trace_strings: bool = True) -> str lines.append("") + # Per-rule dependency bitmasks for dirty-flag rule skip + rule_dep_masks = _compute_rule_dep_masks(model) + if rule_dep_masks: + lines.append("/* Per-rule input fact dependency bitmasks (dirty-flag skip). */") + mask_strs = [f"UINT64_C(0x{m:016x})" for m in rule_dep_masks] + lines.append( + "#define ARBITER_MODEL_RULE_DEPS { " + + ", ".join(mask_strs) + + " }" + ) + lines.append("") + # State defines (REQ-ARCH-039) states = getattr(model, "states", []) if states: @@ -224,12 +297,17 @@ def emit_c_source(model: CanonicalModel, header_name: str = "arbiter_model.h", expr_start = r.get("_expr_start", 0) expr_count = r.get("_expr_count", 0) + + # Mode-aware pruning: mode_guard rules with a mode equality check + required_mode = _compute_required_mode(r, model) + lines.append( f"\t{{ .id = {i}, .rule_class = {rclass}, " f".condition_start = {cond_offset}, .condition_count = {cond_count}, " f".action_start = {action_start}, .action_count = {action_count}, " f".expr_start = {expr_start}, .expr_count = {expr_count}, " f".safety_goal_id = UINT16_MAX, .set_mode = {set_mode}, " + f".required_mode = {required_mode}, " f".safety_critical = {safety_critical}, " f".name = {name}, .explanation = {explanation} }}," ) diff --git a/subsys/arbiter/Kconfig b/subsys/arbiter/Kconfig index 3a306f2..6d56de6 100644 --- a/subsys/arbiter/Kconfig +++ b/subsys/arbiter/Kconfig @@ -292,6 +292,30 @@ config ARBITER_FPGA_OFFLOAD to offload condition evaluation and expression execution to FPGA fabric. No implementation is shipped in v1. +# ── Dirty-Flag Rule Skip ───────────────────────────────── + +config ARBITER_DIRTY_SKIP + bool "Enable dirty-flag rule skip optimisation" + default y if ARBITER_RESOLVED_STANDARD || ARBITER_RESOLVED_FULL + default n + help + Track which facts each rule depends on. If none of a + rule's input facts changed since the last evaluation, + skip the rule entirely (keep previous result). Safe to + disable — behaviour is identical to the unoptimised path. + Requires <=64 facts (uses a uint64_t bitmask). + +# ── Condition Result Cache ─────────────────────────────── + +config ARBITER_COND_CACHE + bool "Enable condition result cache" + default y + help + Cache the results of the 8 most recently evaluated + conditions during an eval cycle. Helps when multiple + rules test the same condition (e.g. temp > 80). + The cache is reset at the start of each evaluation. + # ── Event-Driven Evaluation (REQ-ARCH-036) ─────────────── config ARBITER_EVENT_DRIVEN