[CF] E. Mycraft Sand Sort - Codeforces Round 1005 (Div. 2)

https://codeforces.com/contest/2064/problem/E

题目大意

给出 $n$ 长的排列 $p$ 和 $n$ 长的颜色数组 $c$。从上到下,从 $i = 1 \sim n$ 将 $p_i$ 个颜色为 $c_i$ 的左侧对齐横向排列,之后让其按照重力规则下落,问有多少个不同的排列和颜色方案的二元组,可以得到和题目给出的排列颜色组,下落之后相同的图形。

例如 $p = [4, 2, 3, 1, 5], c = [2, 1, 4, 1, 5]$

原图为

2 2 2 2
1 1
4 4 4 
1
5 5 5 5 5

下落后为

2
1 2
4 1 2
1 4 4 2
5 5 5 5 5

$1 \le n \le 2 \times 10^5$。$p$ 为 $1 \sim n$ 的排列。$1 \le c \le n$。

简要题解

观察:

  1. 第一列在下落前后不会变。
  2. 第一列也就颜色数组!就是说颜色数组其实不会变。
  3. 对于第 $i$ 列,剩下的都是长度,即 $p_j \ge i$ 的,而这些还会按照原先的长度排列。
  4. 当两种相同颜色的 $c_i = c_j$,其中间夹的其他颜色长度都比其短时,即 $\min(p_i, p_j) > \max(p_k), k \in [i + 1, j - 1]$,意味着一定有在 $\max(p_k) + 1$ 列时,这两个的颜色块会并在一起,并且意味着这两个排列其实是可交换的。
  5. 可交换的条件实际上更松。实际上对于 $i$ 它可以与其两侧所有的这样的 $j$ 交换:$c_j = c_i$ 且 $p_j > p_i$ 且 $i, j$ 之间没有大于 $p_j$ 的其他颜色的 $p_k$。 因为由我们上面的结论我们知道这两个颜色总会在 $\max(p_k) + 1, k \in [i + 1, j - 1]$ 这个时点前并到一簇连续的同颜色区间。而一旦两簇并到了一起,它俩之间的相对顺序就不重要了。
  6. 很重要的一点,$p$ 只可能在颜色内重排,因为我们可以把其他颜色都抽掉,剩下的就代表了 $p_i$ 是哪些值,这也是唯一的。(做题时没细想这个点,但是很重要,这说明 5 中的情况就是完全的)

由 5 我们知道最后形成的是一个类似树形的结构,当 $i$ 连到了比它更大的 $j$ 上时,只有两种情况,它与 $j$ 的可交换范围相同,它是 $j$ 的可交换范围的子集。我们总是优先安排可交换范围更小的一些,也就是 $p$ 更小的一些。

于是我们第一步总是找到每个位置对应的可交换区间:将所有 $p$ 插入到一棵最大值线段树,然后依次枚举颜色,将这个颜色的所有位置置为 $0$,之后对这个颜色的每个位置 $i$ 左右两侧二分最大的区间,使得这个区间的最值不超过 $p_i$。这个过程写的两个 $\log$ 的做法(二分 + 线段树查询,不太确定能不能直接线段树二分,以及复杂度会不会更好)。查询结束后恢复线段树中当前颜色所有对应位置的值。这样就可以对于每种颜色只考虑其他颜色的限制了。(题解这部分给的是直接用并查集维护每个块的边界,复杂度会好一些)

之后因为优先安排小的不会使大的无法安排,因此我们每种颜色每个位置只需要关注,其可操作区间有多少更大同色可交换的,算上其本身的位置,这就是 $i$ 所有可被安排的可能性,乘到答案上即可(比它更小的也可能与其交换,但是本质是,更小的被安排后会占掉一些位置,而这部分可能性,更小的会计算,因此只需考虑更大的还没占的位置即可)。

复杂度

$T$:$O(n \log ^ 2 (n))$

$S$:$O(n)$

代码实现

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) { 
    return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) { 
    return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<long long Mo=998244353> struct ModInt {
  static long long MO;
  static void setMo(long long mo) { MO = mo; }
  long long x;
  ModInt(long long x=0) : x(x){ norm(); }
  friend istream &operator>>(istream& in, ModInt &B) { in>>B.x; return in; }
  friend ostream &operator<<(ostream& out, const ModInt &B) { 
    out<<B.x; return out; }
  // ModInt operator=(int x_) { x=x_; norm(); return *this; }
  void norm() { x = (x%MO + MO) % MO; }
  long long get() { return x; }

  ModInt operator-() const { return ModInt(MO - x); }
  ModInt operator+=(const ModInt &B) { x+=B.x; if(x>=MO) x-=MO; return *this; }
  ModInt operator-=(const ModInt &B) { x-=B.x; if(x<0) x+=MO; return *this; }
  ModInt operator*=(const ModInt &B) { x=x*B.x%MO; return *this; }
  ModInt operator+(const ModInt &B) const { ModInt ans=*this; return ans+=B; }
  ModInt operator-(const ModInt &B) const { ModInt ans=*this; return ans-=B; }
  ModInt operator*(const ModInt &B) const { ModInt ans=*this; return ans*=B; }
  ModInt operator^(long long n) const  {
    ModInt a=*this; ModInt ans(1);
    while(n) { if(n&1) ans*=a; a*=a; n>>=1; }
    return ans;
  }
  ModInt inv() const { return (*this)^(MO-2); } // if MO is prime
  ModInt operator/=(const ModInt &B) { (*this)*=B.inv(); return *this; }
  ModInt operator/(const ModInt &B) const { ModInt ans=*this; return ans/=B; }

  bool operator<(const ModInt &B) const { return x<B.x; }
  bool operator==(const ModInt &B) const { return x==B.x; }
  bool operator!=(const ModInt &B) const { return x!=B.x; }
};
template<long long Mo> long long ModInt<Mo>::MO = Mo;
typedef ModInt<998244353> Mint;
// typedef ModInt<1'000'000'007> Mint;

template<typename T> struct GetZero { T operator()() const { return T(0); } };
template<typename T,
         typename OpPlus=plus<T>,typename OpMinus=minus<T>,
         typename Zero=GetZero<T> >
struct BIT {
  static int lowbit(int x) { return x&(-x); }
  constexpr static OpPlus opp{};
  constexpr static OpMinus opm{};
  constexpr static Zero zero{};
  int n;
  vector<T> tree; // tree[i] -> sum of [i-lowbit(i)+1,i]
  BIT(int n_=0):n(n_),tree(n+1,zero()) {}
  void init(int n_) { n=n_; tree.assign(n+1,zero()); }
  void init(const vector<T> &vec) { // v[0 ~ n_-1]
    n=vec.size();
    vector<T> tmp(n+1,zero());
    for(int i=1;i<=n;i++) tmp[i]=opp(tmp[i-1],vec[i-1]);
    for(int i=1;i<=n;i++) tree[i]=opm(tmp[i],tmp[i-lowbit(i)]);
  }
  void add(int p,T v) { 
    for(;p<=n;p+=lowbit(p)) tree[p]=opp(tree[p],v);
  }
  T sum(int p) {
    T ans=zero();
    for(;p;p-=lowbit(p)) ans=opp(ans,tree[p]);
    return ans;
  }
  T sum(int l,int r) { return opm(sum(r),sum(l-1)); }
};

/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

struct SegT {
  int n;
  vector<int> mx;
  SegT(int n) : n(n), mx(n << 2) {}
  int P, V;
  void update(int o, int l, int r) {
    if (l == r) {
      mx[o] = V;
      return;
    }
    int mid = (l + r) >> 1;
    if (P <= mid) update(o << 1, l, mid);
    else update(o << 1 | 1, mid + 1, r);

    mx[o] = max(mx[o << 1], mx[o << 1 | 1]);
  }

  int L, R;
  int query(int o, int l, int r) {
    if (L <= l && r <= R) {
      return mx[o];
    }
    int mid = (l + r) >> 1;
    int ans = 0;
    if (L <= mid) cmax(ans, query(o << 1, l, mid));
    if (mid < R) cmax(ans, query(o << 1 | 1, mid + 1, r));
    return ans;
  }

  int bstl() {
    int l = 0, r = R - 1;
    int ans = R, mid;
    while (l <= r) {
      mid = (l + r) >> 1;
      L = mid;
      if (query(1, 0, n - 1) <= V) {
        ans = mid;
        r = mid - 1;
      } else {
        l = mid + 1;
      }
    }
    return ans;
  }
  int bstr() {
    int l = L + 1, r = n - 1;
    int ans = L, mid;
    while (l <= r) {
      mid = (l + r) >> 1;
      R = mid;
      if (query(1, 0, n - 1) <= V) {
        ans = mid;
        l = mid + 1;
      } else {
        r = mid - 1;
      }
    }
    return ans;
  }
};

void solve() {
  int n; cin >> n;
  vector<int> p(n), c(n);
  for (int& i : p) cin >> i;
  for (int& i : c) cin >> i;

  vector<int> pos(n + 1);
  for (int i = 0; i < n; i++) {
    pos[p[i]] = i;
  }
  vector<vector<int>> col(n + 1);
  vector<vector<int>> colp(n + 1);
  for (int i = 0; i < n; i++) {
    col[c[i]].push_back(i);
    colp[c[i]].push_back(p[i]);
  }

  SegT segt(n);
  for (int i = 0; i < n; i++) {
    segt.P = i;
    segt.V = p[i];
    segt.update(1, 0, n - 1);
  }

  vector<int> l(n), r(n);
  for (int ii = 1; ii <= n; ii++) {
    if (col[ii].empty()) continue;

    for (int i : col[ii]) {
      segt.P = i;
      segt.V = 0;
      segt.update(1, 0, n - 1);
    }

    for (int i : col[ii]) {
      segt.R = i;
      segt.V = p[i];
      l[i] = segt.bstl();
      segt.L = i;
      r[i] = segt.bstr();
    }

    for (int i : col[ii]) {
      segt.P = i;
      segt.V = p[i];
      segt.update(1, 0, n - 1);
    }
  }

  // for (int i = 0; i < n; i++) {
  //   cerr << i << ' ' << l[i] << ' ' << r[i] << '\n';
  // }

  Mint ans = 1;
  BIT<int> bit(n);
  for (int ii = 1; ii <= n; ii++) {
    if (colp[ii].empty()) continue;

    sort(colp[ii].rbegin(), colp[ii].rend());

    for (int i : colp[ii]) {
      i = pos[i];
      bit.add(i + 1, 1);
      ans *= bit.sum(l[i] + 1, r[i] + 1);
    }
    for (int i : colp[ii]) {
      i = pos[i];
      bit.add(i + 1, -1);
    }
  }

  cout << ans << '\n';
}

int main() {
  int t = 1; 
  cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}

数据生成

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());

/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

inline int rndi() { return rnd(); }
inline int rndi(int n) { return rnd()%n; }
inline int rndi(int l, int r) { return rnd() % (r - l + 1) + l; }
inline LL rndll() { return rnd_64(); }
inline LL rndll(LL n) { return rnd_64() % n; }
inline LL rndll(LL l, LL r) { return rnd_64() % (r - l + 1) + l; }
inline char rndc() { return rnd() % 26 + 'a'; }
inline char rndC() { return rnd() % 26 + 'A'; }
inline char rnddig() { return rnd() % 10 + '0'; }
inline char rndcha() { 
  int v = rnd() % 52;
  return v < 26 ? v + 'a' : (v - 26 + 'A');
}

template<typename T>
void shuffle(vector<T>& vec) { shuffle(vec.begin(), vec.end(), rnd); }

vector<int> rnd_permu(int n, int from = 0) {
  vector<int> vec(n);
  iota(vec.begin(), vec.end(), from);
  shuffle(vec);
  return vec;
}

void solve() {
  int n = 5; cout << n << '\n';
  auto p = rnd_permu(n, 1);
  for (int i : p) cout << i << ' ';
  cout << '\n';

  for (int i = 0; i < n; i++) {
    cout << rndi(1, n) << ' ';
  }
  cout << '\n';
}

int main() {
  int t = 1; cout << t << '\n';
  while (t--) {
    solve();
  }
  return 0;
}

暴力

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) { 
    return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) { 
    return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());

/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

void solve() {
  int n; cin >> n;
  vector<int> p(n), c(n);
  for (int& i : p) cin >> i;
  for (int& i : c) cin >> i;

  auto q = p;
  sort(q.begin(), q.end());

  auto getcol = [&](const vector<int>& a) -> vector<vector<int>> {
    vector<vector<int>> ans;
    for (int i = 1; i <= n; i++) {
      vector<int> cur;
      for (int j = 0; j < n; j++) {
        if (a[j] >= i) {
          cur.push_back(c[j]);
        }
      }

      ans.push_back(cur);
    }
    return ans;
  };

  auto colp = getcol(p);
  
  int ans = 0;
  do {
    if (getcol(q) == colp) ans++;
  } while (next_permutation(q.begin(), q.end()));

  cout << ans << '\n';
}

int main() {
  int t = 1; 
  cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}

一组有启发性的小数据

1
5
5 3 1 4 2
5 2 3 5 5
Prev: [CF] C. Vasya And The Mushrooms - Educational Codeforces Round 48 (Rated for Div. 2)