[CF] E. Vasya and Binary String - Educational Codeforces Round 59 (Rated for Div. 2)

https://codeforces.com/contest/1107/problem/E

题目大意

给出一个 $n$ 长的 $01$ 串 $S$ 和数组 $a$。每次操作可以从中删除一个连续的,相同字符的段,并把剩余的段按顺序连起来。每次删掉 $len$ 长的段的收益为 $a_{len}$,问最大收益是多少。

$1 \le n \le 100$。$1 \le a_i \le 10^9$。

简要题解

观察:

  1. 因为权重都是正的,因此最后一定是把所有颜色删完最好。
  2. $a$ 并不是实际上的最优值,因为可以删多次小段,我们可以完全背包解决这个部分。
  3. 对于整段最后一次一定是删掉了某种颜色所有剩余的频率。
  4. $n$ 不大可以考虑区间 dp。
  5. 对于小的区间,每次显然最后剩余一种颜色即可,因为大区间也只用处理一种颜色。

有了这些观察我们只要定义 $dp[l][r][cnt][col]$ 为区间 $[l, r]$ 最后剩下 $cnt$ 个 $col$ 时,所能得到的最大权值即可。转移只需考虑枚举 $[l, mid]$ 或 $[mid + 1, r]$ 提供所有的 $cnt$ 频率,以及 $l$ 或者 $r$ 提供一个频率,其他部分提供 $cnt - 1$ 即可。这里并不需要枚举所有中间点并枚举两侧的频率,例如有一个频率分别为 $x, y$ 的切割,必然也会有一个上述的等价切割。

注意因为转移需要某段完全处理掉两种颜色的情况,因此我们不妨在求完所有 $len = [1, r - l + 1]$ 的情况后将其所有频率处理掉转移到 $len = 0$。

复杂度

$T$:$O(n ^ 4)$:这个显然不是跑的那么满,实际跑了 $124ms$。

$S$:$O(n ^ 3)$

代码实现

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) { 
    return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) { 
    return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());

/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

const int M = 105;

LL mxcost[M];
LL dp[M][M][M][2];

void solve() {
  int n;
  string s;
  cin >> n >> s;
  vector<int> a(n + 1);
  for (int i = 1; i <= n; i++) {
    cin >> a[i];
  }

  for (int i = 1; i <= n; i++) {
    mxcost[i] = 0;
  }
  for (int i = 1; i <= n; i++) {
    for (int j = i; j <= n; j++) {
      cmax(mxcost[j], mxcost[j - i] + a[i]);
    }
  }

  memset(dp, -1, sizeof(dp));

  for (int i = 0; i < n; i++) {
    dp[i][i][0][0] = dp[i][i][0][1] = mxcost[1];
    dp[i][i][1][s[i] - '0'] = 0;
  }

  for (int len = 2; len <= n; len++) {
    for (int l = 0, r = len - 1; r < n; l++, r++) {
      for (int cnt = 1; cnt <= len; cnt++) {
        for (int col = 0; col < 2; col++) {
          if (dp[l + 1][r][cnt][col] != -1) {
            cmax(dp[l][r][cnt][col], dp[l + 1][r][cnt][col] + mxcost[1]);
          }
          if (s[l] - '0' == col && dp[l + 1][r][cnt - 1][col] != -1) {
            cmax(dp[l][r][cnt][col], dp[l + 1][r][cnt - 1][col]);
          }
          if (dp[l][r - 1][cnt][col] != -1) {
            cmax(dp[l][r][cnt][col], dp[l][r - 1][cnt][col] + mxcost[1]);
          }
          if (s[r] - '0' == col && dp[l][r - 1][cnt - 1][col] != -1) {
            cmax(dp[l][r][cnt][col], dp[l][r - 1][cnt - 1][col]);
          }

          for (int mid = l; mid < r; mid++) {
            if (dp[l][mid][cnt][col] != -1 && dp[mid + 1][r][0][col] != -1) {
              cmax(dp[l][r][cnt][col], dp[l][mid][cnt][col] + dp[mid + 1][r][0][col]);
            }
            if (dp[l][mid][0][col] != -1 && dp[mid + 1][r][cnt][col] != -1) {
              cmax(dp[l][r][cnt][col], dp[l][mid][0][col] + dp[mid + 1][r][cnt][col]);
            }
          }
          
          if (dp[l][r][cnt][col] != -1) {
            cmax(dp[l][r][0][col], dp[l][r][cnt][col] + mxcost[cnt]);
            cmax(dp[l][r][0][col ^ 1], dp[l][r][cnt][col] + mxcost[cnt]);
          }
          // if (dp[l][r][cnt][col] != -1) {
          //   cerr << l << ' ' << r << ' ' << cnt << ' ' << col << ' ' << dp[l][r][cnt][col] << endl;
          // } 
        }
      }
    }
  }

  cout << max(dp[0][n - 1][0][0], dp[0][n - 1][0][1]) << '\n';
}

int main() {
  int t = 1; 
  // cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}
Prev: [CF] A. Object Identification - Codeforces Round 1004 (Div. 1)
Next: [CF] D. Compression - Educational Codeforces Round 59 (Rated for Div. 2)