ABC 161D Lunlun Number

本番では解けなかった問題。敗因は全列挙できることに気づけなかったこと。

別解ができたので貼っておく。

考えたこと

ルンルン数が十分密に存在するなら1から順に判定していけるかと思った。しかし次のようなコードでルンルン数の総数を調べたところ無理そうだとわかった。 今になって雑に見積もってみると、解説の解法から考察すると、上限が一桁増えるとルンルン数の総数は大体3倍くらいに増える。 そのため107まで調べても1万個に満たないルンルン数しか発見できず、この解法では間に合わない。

fn main() {
    let n = 100_000_000;
    let mut count = 0;
    for i in 1..=n {
        if check(i) {
            count += 1;
            // eprintln!("{}", i);
        }
    }
    println!("{}", count);
}

fn check(n: u64) -> bool {
    let mut ok = true;
    let s: Vec<u32> = n.to_string().chars().map(|x| x.to_digit(10).unwrap()).collect();
    let prev = s[0];
    for d in s {
        if (d as i64 - prev as i64).abs() > 1 {
            ok = false;
            break;
        }
    }

    ok
}

そこでルンルン数の総数が高速に計算できれば、二分探索と組み合わせてk番目のルンルン数が見つけられるのではないかと考えた。 総数を数えるには、桁DPの考え方が適用できる。

本番中には実装が終わらなかったが、その後仕上げたのが次のコード。

use proconio::input;

#[allow(unused_macros)]
macro_rules! multi_vec {
    ( $elem:expr; $num:expr ) => (vec![$elem; $num]);
    ( $elem:expr; $num:expr, $($rest:expr),* ) => (vec![multi_vec![$elem; $($rest),*]; $num]);
}

fn main() {
    input! {
        k: u64,
    }

    let mut mx = 1;
    while calc(mx) < k {
        mx *= 10;
    }
    // dbg!(mx, calc(mx));
    let mut l = 0;
    let mut r = mx;
    while l + 1 < r {
        let mid = (l + r) / 2;
        if calc(mid) < k {
            l = mid
        } else {
            r = mid
        }
    }
    let ans = r;
    println!("{}", ans);
}

fn calc(m: u64) -> u64 {
    let s: Vec<u8> = m.to_string().chars().map(|x| x as u8 - b'0').collect();
    let l = s.len();
    // dp[i][j] = i桁目まで、jは超えないことが確定したか、kは今の桁数字
    let mut dp = multi_vec![0; l + 1, 2, 10];
    for i in 1..=l {
        let d = s[i - 1] as usize;
        for k in 0..=9 {
            // 超えない確定
            if k > 0 {
                dp[i][1][k] += dp[i - 1][1][k - 1];
            }
            dp[i][1][k] += dp[i - 1][1][k];
            if k < 9 {
                dp[i][1][k] += dp[i - 1][1][k + 1];
            }

            // dbg!(dp[1][1][1]);
            // dbg!(i, k, dp[2][1][2]);

            // 超えるかわからない
            if k > 0 {
                if k < d {
                    dp[i][1][k] += dp[i - 1][0][k - 1];
                } else if k == d {
                    dp[i][0][k] += dp[i - 1][0][k - 1];
                }
            }
            if k < d {
                dp[i][1][k] += dp[i - 1][0][k];
            } else if k == d {
                dp[i][0][k] += dp[i - 1][0][k];
            }
            if k < 9 {
                if k < d {
                    dp[i][1][k] += dp[i - 1][0][k + 1];
                } else if k == d {
                    dp[i][0][k] += dp[i - 1][0][k + 1];
                }
            }

            // dbg!(i, k, dp[2][1][2]);

            // この桁から始まる数字
            if k > 0 {
                if i == 1 {
                    if k < d {
                        dp[i][1][k] += 1;
                    } else if k == d {
                        dp[i][0][k] += 1;
                    }
                } else if i > 1 {
                    dp[i][1][k] += 1;
                }
            }

            // dbg!(i, k, dp[2][1][2]);
        }
    }

    // dbg!(m, &dp);

    let mut ret = 0;
    for i in 0..=9 {
        ret += dp[l][0][i] + dp[l][1][i]
    }

    ret
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn calc_test() {
        assert_eq!(calc(1), 1);
        assert_eq!(calc(2), 2);
        assert_eq!(calc(9), 9);
        assert_eq!(calc(10), 10);
        assert_eq!(calc(11), 11);
        assert_eq!(calc(12), 12);
        assert_eq!(calc(21), 13);
        assert_eq!(calc(23), 15);
    }
}