https://codeforces.com/contest/808/problem/E
题目大意
给出 $n \ (n \le 10^5)$ 个物品和总容积为 $m \ (m \le 3 \times 10^5)$ 的背包。n 个物品的体积和价值分别为 $w_i \ (1 \le w_i \le 3)$ 和 $c_i \ (1 \le c_i \le 10^9)$。
问可以装入背包的(不超过 $m$ 的)最大价值。
简要题解
显然这是一个 $01$ 背包问题,但是空间和个数很大,无法使用标准的做法。注意到 $w_i$ 只有 $3$ 种,因此我们想到能不能枚举。显然每种体积都会优先从大的开始使用。
显然 $50000 \times 50000$ 的复杂度不是题目设计的算法。(不过事实上这题给了 $2s$,常数较低的 $n ^ 2$ 是可以卡过去的,比如下面 $n^2$ 程序用了 $1546ms$)
由于 $1$ 的性质比较好,我们考虑枚举 $2$ 或者 $3$,dp 处理另一对。
我们考虑 $dp[i]$ 表示 $i$ 体积下最好的解。此时假设有不同的组合都可以得到该最优值,例如 $cnt_1 = x_1, cnt_2 = y_1$ 和 $cnt_1 = x_1, cnt_2 = y_1$ 假定 $c_1$ 和 $c_2$ 是两个从大到小排好序的数组,则这意味着 $dp[i] = \sum_{j = 1}^{x_1} c_1[j] + \sum_{j = 1}^{y_1} c_2[j] = \sum_{j = 1}^{x_2} c_1[j] + \sum_{j = 1}^{y_2} c_2[j]$ (假设 $x_1 < x_2$ 则 $sum_{i = x_1 + 1}^{x_2} = sum_{i = x_1 + 1}^{x_2}$)也就是说一段体积 $2$ 的和与一段体积 $1$ 的和是一样的。那么对于更大的 $dp[i + 2]$,无论 $dp[i]$ 我们记录的是哪一种路径转移过来的,我们都还可以使用一对 $1$ 或者某个 $2$,这些 $2$ 的权值都应相同,这些 $1$ 如果多于一对应该也都相同,否则我们显然可以得到更大权值。对于 $dp[i + 1]$ 看似如果我们选择多用 $cnt_1$ 则可能在此时没有可用的 $1$,但实际上我们还有从 $dp[i - 1]$ 来的转移,对于 $i - 1$ 我们将至少有一个多出的 $1$ 和一个多出的 $2$ 则此时我们沿着这条路径可以得到正确的转移。也就是说,其实转移的过程我们只需要知道某条最优解的路径即可。
思考:是否枚举 $1$ 对于 $2, 3$ 也是可行的,或者说是否这个性质对于任何的一对 $x, y$ 都是成立的。
复杂度
$T$:$O(n \log n)$:$log$ 是来自于排序。
$S$:$O(n)$
代码实现
$O(n \log n)$ 的做法只要 $93ms$
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
typedef array<LL, 3> A3;
void solve() {
int n, m; cin >> n >> m;
int w, c;
vector<vector<LL>> cs(4);
for (int i = 0; i < n; i++) {
cin >> w >> c;
cs[w].push_back(c);
}
for (int i = 1; i <= 3; i++) {
sort(cs[i].rbegin(), cs[i].rend());
}
vector<A3> dp(m + 2);
for (int i = 0; i < m; i++) {
auto [v, c1, c2] = dp[i];
if (c1 < (int)cs[1].size()) cmax(dp[i + 1], {v + cs[1][c1], c1 + 1, c2});
if (c2 < (int)cs[2].size()) cmax(dp[i + 2], {v + cs[2][c2], c1, c2 + 1});
}
for (int i = 1; i <= m; i++) {
cmax(dp[i], dp[i - 1]);
}
LL ans = dp[m][0];
LL sum = 0;
for (int i = 0; i < (int)cs[3].size(); i++) {
sum += cs[3][i];
int left = m - 3 * (i + 1);
if (left < 0) break;
cmax(ans, dp[left][0] + sum);
}
cout << ans << '\n';
}
int main() {
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}
枚举 $2$
typedef array<LL, 3> A3;
void solve() {
int n, m; cin >> n >> m;
int w, c;
vector<vector<LL>> cs(4);
for (int i = 0; i < n; i++) {
cin >> w >> c;
cs[w].push_back(c);
}
for (int i = 1; i <= 3; i++) {
sort(cs[i].rbegin(), cs[i].rend());
}
vector<A3> dp(m + 3);
for (int i = 0; i < m; i++) {
auto [v, c1, c3] = dp[i];
if (c1 < (int)cs[1].size()) cmax(dp[i + 1], {v + cs[1][c1], c1 + 1, c3});
if (c3 < (int)cs[3].size()) cmax(dp[i + 3], {v + cs[3][c3], c1, c3 + 1});
}
for (int i = 1; i <= m; i++) {
cmax(dp[i], dp[i - 1]);
}
LL ans = dp[m][0];
LL sum = 0;
for (int i = 0; i < (int)cs[2].size(); i++) {
sum += cs[2][i];
int left = m - 2 * (i + 1);
if (left < 0) break;
cmax(ans, dp[left][0] + sum);
}
cout << ans << '\n';
}
时间复杂度 $n^2$ 的代码:
void solve() {
int n, m; cin >> n >> m;
int w, c;
vector<vector<LL>> cs(4);
for (int i = 0; i < n; i++) {
cin >> w >> c;
cs[w].push_back(c);
}
for (int i = 1; i <= 3; i++) {
sort(cs[i].rbegin(), cs[i].rend());
cs[i].insert(cs[i].begin(), 0);
for (int j = 1; j < (int)cs[i].size(); j++) {
cs[i][j] += cs[i][j - 1];
}
}
LL ans = 0;
for (int i = 0; i < (int)cs[2].size(); i++) {
for (int j = 0; j < (int)cs[3].size(); j++) {
int left = m - i * 2 - j * 3;
if (left < 0) break;
LL sum = cs[2][i] + cs[3][j] + cs[1][min(left, (int)cs[1].size() - 1)];
cmax(ans, sum);
}
}
cout << ans << '\n';
}
Next: [CF] B. Buggy Robot - Educational Codeforces Round 32