[CF] E. Selling Souvenirs - Educational Codeforces Round 21

https://codeforces.com/contest/808/problem/E

题目大意

给出 $n \ (n \le 10^5)$ 个物品和总容积为 $m \ (m \le 3 \times 10^5)$ 的背包。n 个物品的体积和价值分别为 $w_i \ (1 \le w_i \le 3)$ 和 $c_i \ (1 \le c_i \le 10^9)$。

问可以装入背包的(不超过 $m$ 的)最大价值。

简要题解

显然这是一个 $01$ 背包问题,但是空间和个数很大,无法使用标准的做法。注意到 $w_i$ 只有 $3$ 种,因此我们想到能不能枚举。显然每种体积都会优先从大的开始使用。

显然 $50000 \times 50000$ 的复杂度不是题目设计的算法。(不过事实上这题给了 $2s$,常数较低的 $n ^ 2$ 是可以卡过去的,比如下面 $n^2$ 程序用了 $1546ms$)

由于 $1$ 的性质比较好,我们考虑枚举 $2$ 或者 $3$,dp 处理另一对。

我们考虑 $dp[i]$ 表示 $i$ 体积下最好的解。此时假设有不同的组合都可以得到该最优值,例如 $cnt_1 = x_1, cnt_2 = y_1$ 和 $cnt_1 = x_1, cnt_2 = y_1$ 假定 $c_1$ 和 $c_2$ 是两个从大到小排好序的数组,则这意味着 $dp[i] = \sum_{j = 1}^{x_1} c_1[j] + \sum_{j = 1}^{y_1} c_2[j] = \sum_{j = 1}^{x_2} c_1[j] + \sum_{j = 1}^{y_2} c_2[j]$ (假设 $x_1 < x_2$ 则 $sum_{i = x_1 + 1}^{x_2} = sum_{i = x_1 + 1}^{x_2}$)也就是说一段体积 $2$ 的和与一段体积 $1$ 的和是一样的。那么对于更大的 $dp[i + 2]$,无论 $dp[i]$ 我们记录的是哪一种路径转移过来的,我们都还可以使用一对 $1$ 或者某个 $2$,这些 $2$ 的权值都应相同,这些 $1$ 如果多于一对应该也都相同,否则我们显然可以得到更大权值。对于 $dp[i + 1]$ 看似如果我们选择多用 $cnt_1$ 则可能在此时没有可用的 $1$,但实际上我们还有从 $dp[i - 1]$ 来的转移,对于 $i - 1$ 我们将至少有一个多出的 $1$ 和一个多出的 $2$ 则此时我们沿着这条路径可以得到正确的转移。也就是说,其实转移的过程我们只需要知道某条最优解的路径即可。

思考:是否枚举 $1$ 对于 $2, 3$ 也是可行的,或者说是否这个性质对于任何的一对 $x, y$ 都是成立的。

复杂度

$T$:$O(n \log n)$:$log$ 是来自于排序。

$S$:$O(n)$

代码实现

$O(n \log n)$ 的做法只要 $93ms$

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) { 
    return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) { 
    return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());

/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

typedef array<LL, 3> A3;

void solve() {
  int n, m; cin >> n >> m;
  int w, c;
  vector<vector<LL>> cs(4);
  for (int i = 0; i < n; i++) {
    cin >> w >> c;
    cs[w].push_back(c);
  }

  for (int i = 1; i <= 3; i++) {
    sort(cs[i].rbegin(), cs[i].rend());
  }

  vector<A3> dp(m + 2);
  for (int i = 0; i < m; i++) {
    auto [v, c1, c2] = dp[i];
    if (c1 < (int)cs[1].size()) cmax(dp[i + 1], {v + cs[1][c1], c1 + 1, c2});
    if (c2 < (int)cs[2].size()) cmax(dp[i + 2], {v + cs[2][c2], c1, c2 + 1});
  }
  for (int i = 1; i <= m; i++) {
    cmax(dp[i], dp[i - 1]);
  }

  LL ans = dp[m][0];
  LL sum = 0;
  for (int i = 0; i < (int)cs[3].size(); i++) {
    sum += cs[3][i];
    int left = m - 3 * (i + 1);
    if (left < 0) break;
    cmax(ans, dp[left][0] + sum);
  }
  cout << ans << '\n';
}

int main() {
  int t = 1; 
  // cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}

枚举 $2$

typedef array<LL, 3> A3;

void solve() {
  int n, m; cin >> n >> m;
  int w, c;
  vector<vector<LL>> cs(4);
  for (int i = 0; i < n; i++) {
    cin >> w >> c;
    cs[w].push_back(c);
  }

  for (int i = 1; i <= 3; i++) {
    sort(cs[i].rbegin(), cs[i].rend());
  }

  vector<A3> dp(m + 3);
  for (int i = 0; i < m; i++) {
    auto [v, c1, c3] = dp[i];
    if (c1 < (int)cs[1].size()) cmax(dp[i + 1], {v + cs[1][c1], c1 + 1, c3});
    if (c3 < (int)cs[3].size()) cmax(dp[i + 3], {v + cs[3][c3], c1, c3 + 1});
  }
  for (int i = 1; i <= m; i++) {
    cmax(dp[i], dp[i - 1]);
  }

  LL ans = dp[m][0];
  LL sum = 0;
  for (int i = 0; i < (int)cs[2].size(); i++) {
    sum += cs[2][i];
    int left = m - 2 * (i + 1);
    if (left < 0) break;
    cmax(ans, dp[left][0] + sum);
  }
  cout << ans << '\n';
}

时间复杂度 $n^2$ 的代码:

void solve() {
  int n, m; cin >> n >> m;
  int w, c;
  vector<vector<LL>> cs(4);
  for (int i = 0; i < n; i++) {
    cin >> w >> c;
    cs[w].push_back(c);
  }

  for (int i = 1; i <= 3; i++) {
    sort(cs[i].rbegin(), cs[i].rend());
    cs[i].insert(cs[i].begin(), 0);
    for (int j = 1; j < (int)cs[i].size(); j++) {
      cs[i][j] += cs[i][j - 1];
    }
  }

  LL ans = 0;
  for (int i = 0; i < (int)cs[2].size(); i++) {
    for (int j = 0; j < (int)cs[3].size(); j++) {
      int left = m - i * 2 - j * 3;
      if (left < 0) break;
      
      LL sum = cs[2][i] + cs[3][j] + cs[1][min(left, (int)cs[1].size() - 1)];
      cmax(ans, sum);
    }
  }

  cout << ans << '\n';
}
Prev: [CF] D. Array Division - Educational Codeforces Round 21
Next: [CF] B. Buggy Robot - Educational Codeforces Round 32