数位DP用于解决数字各位相关的计数问题,例如统计区间内满足特定条件的数字数量。其核心是通过动态规划逐位处理数字,利用记忆化技术避免重复计算。
- 拆解数位:将数字转换为字符数组,逐位处理。
- 状态记录:记录当前位置、是否受上界限制、前导零状态及其他条件。
- 记忆化搜索:缓存已计算的状态,优化时间复杂度。
- 预处理数位:将数字转换为字符串或数组。
- 递归处理每一位:
- 限制条件:当前位是否受上界限制。
- 前导零处理:标记是否处于前导零状态。
- 状态转移:根据当前位选择更新状态。
- 边界处理:处理完所有位后返回结果。
from functools import lru_cache
def count_special_numbers(n: int) -> int:
s = str(n)
@lru_cache(maxsize=None)
def dp(pos: int, mask: int, tight: bool, lead: bool) -> int:
if pos == len(s):
return 0 if lead else 1
limit = int(s[pos]) if tight else 9
total = 0
for d in range(0, limit + 1):
new_tight = tight and (d == limit)
new_lead = lead and (d == 0)
if new_lead:
total += dp(pos + 1, mask, new_tight, new_lead)
else:
if (mask & (1 << d)) == 0:
new_mask = mask | (1 << d)
total += dp(pos + 1, new_mask, new_tight, new_lead)
return total
return dp(0, 0, True, True)
# 示例:统计1到n中无重复数字的数目
print(count_special_numbers(20)) # 输出19(1-20中除11外都符合)package main
import (
"fmt"
"strconv"
)
func countSpecialNumbers(n int) int {
s := strconv.Itoa(n)
m := len(s)
memo := make([][1 << 10]int, m)
for i := range memo {
for j := range memo[i] {
memo[i][j] = -1 // -1 表示没有计算过
}
}
var dfs func(int, int, bool, bool) int
dfs = func(i, mask int, isLimit, isNum bool) (res int) {
if i == m {
if isNum {
return 1 // 得到了一个合法数字
}
return
}
if !isLimit && isNum {
p := &memo[i][mask]
if *p >= 0 { // 之前计算过
return *p
}
defer func() { *p = res }() // 记忆化
}
if !isNum { // 可以跳过当前数位
res += dfs(i+1, mask, false, false)
}
d := 0
if !isNum {
d = 1 // 如果前面没有填数字,必须从 1 开始(因为不能有前导零)
}
up := 9
if isLimit {
up = int(s[i] - '0') // 如果前面填的数字都和 n 的一样,那么这一位至多填数字 s[i](否则就超过 n 啦)
}
for ; d <= up; d++ { // 枚举要填入的数字 d
if mask>>d&1 == 0 { // d 不在 mask 中,说明之前没有填过 d
res += dfs(i+1, mask|1<<d, isLimit && d == up, true)
}
}
return
}
return dfs(0, 0, true, false)
}| 参数 | 说明 |
|---|---|
pos |
当前处理的数位位置(从高位到低位)。 |
mask |
状态掩码,记录已使用的数字(例如用位掩码表示)。 |
tight |
是否受上界限制(如处理到第i位时,前i-1位是否与上界相同)。 |
lead |
是否处于前导零状态(前导零不计入已使用数字)。 |
- 无重复数字计数:如示例所示。
- 数位和限制:统计数位和等于特定值的数字。
- 特定模式匹配:如包含/不包含某些子序列。
通过合理设计状态转移和记忆化策略,数位DP能高效解决复杂的数位计数问题。模板可根据具体问题调整状态定义和转移逻辑。
from functools import cache
class Solution:
def numberOfPowerfulInt(self, start: int, finish: int, limit: int, s: str) -> int:
high = list(map(int, str(finish))) # 避免在 dfs 中频繁调用 int()
n = len(high)
low = list(map(int, str(start).zfill(n))) # 补前导零,和 high 对齐
diff = n - len(s)
@cache
def dfs(i: int, limit_low: bool, limit_high: bool) -> int:
if i == n:
return 1
# 第 i 个数位可以从 lo 枚举到 hi
# 如果对数位还有其它约束,应当只在下面的 for 循环做限制,不应修改 lo 或 hi
lo = low[i] if limit_low else 0
hi = high[i] if limit_high else 9
res = 0
if i < diff: # 枚举这个数位填什么
for d in range(lo, min(hi, limit) + 1):
res += dfs(i + 1, limit_low and d == lo, limit_high and d == hi)
else: # 这个数位只能填 s[i-diff]
x = int(s[i - diff])
if lo <= x <= hi: # 题目保证 x <= limit,无需判断
res = dfs(i + 1, limit_low and x == lo, limit_high and x == hi)
return res
return dfs(0, True, True)package main
func numberOfPowerfulInt(start, finish int64, limit int, s string) int64 {
low := strconv.FormatInt(start, 10)
high := strconv.FormatInt(finish, 10)
n := len(high)
low = strings.Repeat("0", n-len(low)) + low // 补前导零,和 high 对齐
diff := n - len(s)
memo := make([]int64, n)
for i := range memo {
memo[i] = -1
}
var dfs func(int, bool, bool) int64
dfs = func(i int, limitLow, limitHigh bool) (res int64) {
if i == n {
return 1
}
if !limitLow && !limitHigh {
p := &memo[i]
if *p >= 0 {
return *p
}
defer func() { *p = res }()
}
// 第 i 个数位可以从 lo 枚举到 hi
// 如果对数位还有其它约束,应当只在下面的 for 循环做限制,不应修改 lo 或 hi
lo := 0
if limitLow {
lo = int(low[i] - '0')
}
hi := 9
if limitHigh {
hi = int(high[i] - '0')
}
if i < diff { // 枚举这个数位填什么
for d := lo; d <= min(hi, limit); d++ {
res += dfs(i+1, limitLow && d == lo, limitHigh && d == hi)
}
} else { // 这个数位只能填 s[i-diff]
x := int(s[i-diff] - '0')
if lo <= x && x <= hi { // 题目保证 x <= limit,无需判断
res += dfs(i+1, limitLow && x == lo, limitHigh && x == hi)
}
}
return
}
return dfs(0, true, true)
}# 代码示例:返回 [low, high] 中的恰好包含 target 个 0 的数字个数
# 比如 digitDP(0, 10, 1) == 2
# 要点:我们统计的是 0 的个数,需要区分【前导零】和【数字中的零】,前导零不能计入,而数字中的零需要计入
def digitDP(low: int, high: int, target: int) -> int:
low_s = list(map(int, str(low))) # 避免在 dfs 中频繁调用 int()
high_s = list(map(int, str(high)))
n = len(high_s)
diff_lh = n - len(low_s)
@cache
def dfs(i: int, cnt0: int, limit_low: bool, limit_high: bool) -> int:
if cnt0 > target:
return 0 # 不合法
if i == n:
return 1 if cnt0 == target else 0
lo = low_s[i - diff_lh] if limit_low and i >= diff_lh else 0
hi = high_s[i] if limit_high else 9
res = 0
start = lo
# 通过 limit_low 和 i 可以判断能否不填数字,无需 is_num 参数
# 如果前导零不影响答案,去掉这个 if block
if limit_low and i < diff_lh:
# 不填数字,上界不受约束
res = dfs(i + 1, 0, True, False)
start = 1
for d in range(start, hi + 1):
res += dfs(i + 1,
cnt0 + (1 if d == 0 else 0), # 统计 0 的个数
limit_low and d == lo,
limit_high and d == hi)
# res %= MOD
return res
return dfs(0, 0, True, True)// 代码示例:返回 [low, high] 中的恰好包含 target 个 0 的数字个数
// 比如 digitDP(0, 10, 1) == 2
// 要点:我们统计的是 0 的个数,需要区分【前导零】和【数字中的零】,前导零不能计入,而数字中的零需要计入
long long digitDP(long long low, long long high, int target) {
string low_s = to_string(low);
string high_s = to_string(high);
int n = high_s.size();
int diff_lh = n - low_s.size();
vector memo(n, vector<long long>(target + 1, -1));
auto dfs = [&](this auto&& dfs, int i, int cnt0, bool limit_low, bool limit_high) -> long long {
if (cnt0 > target) {
return 0; // 不合法
}
if (i == n) {
return cnt0 == target;
}
if (!limit_low && !limit_high && memo[i][cnt0] >= 0) {
return memo[i][cnt0];
}
int lo = limit_low && i >= diff_lh ? low_s[i - diff_lh] - '0' : 0;
int hi = limit_high ? high_s[i] - '0' : 9;
long long res = 0;
int d = lo;
// 通过 limit_low 和 i 可以判断能否不填数字,无需 is_num 参数
// 如果前导零不影响答案,去掉这个 if block
if (limit_low && i < diff_lh) {
// 不填数字,上界不受约束
res = dfs(i + 1, 0, true, false);
d = 1;
}
for (; d <= hi; d++) {
// 统计 0 的个数
res += dfs(i + 1, cnt0 + (d == 0), limit_low && d == lo, limit_high && d == hi);
// res %= MOD;
}
if (!limit_low && !limit_high) {
memo[i][cnt0] = res;
}
return res;
};
return dfs(0, 0, true, true);
}