// AC (Accepted)
//
// 许诺  •  1个月前  (author: 许诺, posted 1 month ago)


#include <bits/stdc++.h>
#include <vector>

using namespace std;

// K: required number of 1-digits; B: numeral base. Both are read from input in main().
int K;
int B;
// Memo table for dfs(), indexed as dp[pos][count][tight]; solve() fills it with -1,
// which marks a state as "not computed yet". 32 digit positions are enough for any
// non-negative int in base >= 2 (assuming a 32-bit int).
int dp[32][32][2];

// Digit DP over the base-B representation `digits` (most significant digit first).
// Counts digit strings that are <= `digits` whenever `tight` is set, use only the
// digits 0 and 1, and contain exactly K ones in total. These are exactly the
// numbers that are a sum of K distinct powers of B.
//
// pos    - index of the digit currently being chosen.
// count  - number of 1-digits chosen so far (never exceeds K thanks to pruning).
// tight  - 1 while the chosen prefix equals the prefix of `digits`, so the
//          current digit may not exceed digits[pos]; 0 otherwise.
// digits - base-B digits of the upper bound. The memo table `dp` is keyed only
//          on (pos, count, tight), so it must be reset (filled with -1)
//          whenever `digits` changes.
// Returns the number of valid completions from this state.
int dfs(int pos, int count, int tight, const std::vector<int>& digits) {
    if (pos == static_cast<int>(digits.size())) {
        return count == K ? 1 : 0;
    }
    int& memo = dp[pos][count][tight];
    if (memo != -1) {
        return memo;
    }
    // Only digits 0 and 1 are allowed; while tight, also cap at the bound's digit.
    const int limit = tight ? std::min(digits[pos], 1) : 1;
    int res = 0;
    for (int d = 0; d <= limit; ++d) {
        const int new_count = count + (d == 1);
        if (new_count > K) continue;  // already too many ones; prune this branch
        // Remain tight only if this digit matched the bound's digit exactly.
        const int new_tight = (tight && d == digits[pos]) ? 1 : 0;
        res += dfs(pos + 1, new_count, new_tight, digits);
    }
    memo = res;
    return res;
}

int solve(int num) {

if (num == 0) return 0;
vector<int> digits;
while (num > 0) {
    digits.push_back(num % B);
    num /= B;
}
reverse(digits.begin(), digits.end());
memset(dp, -1, sizeof(dp));
return dfs(0, 0, 1, digits);

}

// Reads X Y K B and prints how many integers in [X, Y] are a sum of exactly
// K distinct powers of B (i.e. have only 0/1 base-B digits with exactly K ones).
// Exits with status 1 on malformed input or a base below 2.
int main() {
    int X = 0;
    int Y = 0;
    if (!(std::cin >> X >> Y >> K >> B)) {
        return 1;  // missing/malformed input: later variables would be unset
    }
    if (B < 2) {
        return 1;  // base 0 divides by zero, base 1 never terminates
    }

    int ans = 0;
    if (Y >= X) {
        // solve(n) counts over [0, n], so the range count is a prefix difference.
        // Skip solve(X - 1) when X <= 0: nothing below 0 is counted, and
        // X - 1 could overflow for X == INT_MIN.
        ans = solve(Y) - (X > 0 ? solve(X - 1) : 0);
    }
    std::cout << ans << '\n';
    return 0;
}


// 评论：(Comments)
//
// 请先登录，才能进行评论 (Please log in before commenting.)