https://codeforces.com/contest/2064/problem/E
题目大意
给出 $n$ 长的排列 $p$ 和 $n$ 长的颜色数组 $c$。从上到下,从 $i = 1 \sim n$ 将 $p_i$ 个颜色为 $c_i$ 的左侧对齐横向排列,之后让其按照重力规则下落,问有多少个不同的排列和颜色方案的二元组,可以得到和题目给出的排列颜色组,下落之后相同的图形。
例如 $p = [4, 2, 3, 1, 5], c = [2, 1, 4, 1, 5]$
原图为
2 2 2 2
1 1
4 4 4
1
5 5 5 5 5
下落后为
2
1 2
4 1 2
1 4 4 2
5 5 5 5 5
$1 \le n \le 2 \times 10^5$。$p$ 为 $1 \sim n$ 的排列。$1 \le c \le n$。
简要题解
观察:
- 第一列在下落前后不会变。
- 第一列也就颜色数组!就是说颜色数组其实不会变。
- 对于第 $i$ 列,剩下的都是长度,即 $p_j \ge i$ 的,而这些还会按照原先的长度排列。
- 当两种相同颜色的 $c_i = c_j$,其中间夹的其他颜色长度都比其短时,即 $\min(p_i, p_j) > \max(p_k), k \in [i + 1, j - 1]$,意味着一定有在 $\max(p_k) + 1$ 列时,这两个的颜色块会并在一起,并且意味着这两个排列其实是可交换的。
- 可交换的条件实际上更松。实际上对于 $i$ 它可以与其两侧所有的这样的 $j$ 交换:$c_j = c_i$ 且 $p_j > p_i$ 且 $i, j$ 之间没有大于 $p_j$ 的其他颜色的 $p_k$。 因为由我们上面的结论我们知道这两个颜色总会在 $\max(p_k) + 1, k \in [i + 1, j - 1]$ 这个时点前并到一簇连续的同颜色区间。而一旦两簇并到了一起,它俩之间的相对顺序就不重要了。
- 很重要的一点,$p$ 只可能在颜色内重排,因为我们可以把其他颜色都抽掉,剩下的就代表了 $p_i$ 是哪些值,这也是唯一的。(做题时没细想这个点,但是很重要,这说明 5 中的情况就是完全的)
由 5 我们知道最后形成的是一个类似树形的结构,当 $i$ 连到了比它更大的 $j$ 上时,只有两种情况,它与 $j$ 的可交换范围相同,它是 $j$ 的可交换范围的子集。我们总是优先安排可交换范围更小的一些,也就是 $p$ 更小的一些。
于是我们第一步总是找到每个位置对应的可交换区间:将所有 $p$ 插入到一棵最大值线段树,然后依次枚举颜色,将这个颜色的所有位置置为 $0$,之后对这个颜色的每个位置 $i$ 左右两侧二分最大的区间,使得这个区间的最值不超过 $p_i$。这个过程写的两个 $\log$ 的做法(二分 + 线段树查询,不太确定能不能直接线段树二分,以及复杂度会不会更好)。查询结束后恢复线段树中当前颜色所有对应位置的值。这样就可以对于每种颜色只考虑其他颜色的限制了。(题解这部分给的是直接用并查集维护每个块的边界,复杂度会好一些)
之后因为优先安排小的不会使大的无法安排,因此我们每种颜色每个位置只需要关注,其可操作区间有多少更大同色可交换的,算上其本身的位置,这就是 $i$ 所有可被安排的可能性,乘到答案上即可(比它更小的也可能与其交换,但是本质是,更小的被安排后会占掉一些位置,而这部分可能性,更小的会计算,因此只需考虑更大的还没占的位置即可)。
复杂度
$T$:$O(n \log ^ 2 (n))$
$S$:$O(n)$
代码实现
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<long long Mo=998244353> struct ModInt {
static long long MO;
static void setMo(long long mo) { MO = mo; }
long long x;
ModInt(long long x=0) : x(x){ norm(); }
friend istream &operator>>(istream& in, ModInt &B) { in>>B.x; return in; }
friend ostream &operator<<(ostream& out, const ModInt &B) {
out<<B.x; return out; }
// ModInt operator=(int x_) { x=x_; norm(); return *this; }
void norm() { x = (x%MO + MO) % MO; }
long long get() { return x; }
ModInt operator-() const { return ModInt(MO - x); }
ModInt operator+=(const ModInt &B) { x+=B.x; if(x>=MO) x-=MO; return *this; }
ModInt operator-=(const ModInt &B) { x-=B.x; if(x<0) x+=MO; return *this; }
ModInt operator*=(const ModInt &B) { x=x*B.x%MO; return *this; }
ModInt operator+(const ModInt &B) const { ModInt ans=*this; return ans+=B; }
ModInt operator-(const ModInt &B) const { ModInt ans=*this; return ans-=B; }
ModInt operator*(const ModInt &B) const { ModInt ans=*this; return ans*=B; }
ModInt operator^(long long n) const {
ModInt a=*this; ModInt ans(1);
while(n) { if(n&1) ans*=a; a*=a; n>>=1; }
return ans;
}
ModInt inv() const { return (*this)^(MO-2); } // if MO is prime
ModInt operator/=(const ModInt &B) { (*this)*=B.inv(); return *this; }
ModInt operator/(const ModInt &B) const { ModInt ans=*this; return ans/=B; }
bool operator<(const ModInt &B) const { return x<B.x; }
bool operator==(const ModInt &B) const { return x==B.x; }
bool operator!=(const ModInt &B) const { return x!=B.x; }
};
template<long long Mo> long long ModInt<Mo>::MO = Mo;
typedef ModInt<998244353> Mint;
// typedef ModInt<1'000'000'007> Mint;
template<typename T> struct GetZero { T operator()() const { return T(0); } };
template<typename T,
typename OpPlus=plus<T>,typename OpMinus=minus<T>,
typename Zero=GetZero<T> >
struct BIT {
static int lowbit(int x) { return x&(-x); }
constexpr static OpPlus opp{};
constexpr static OpMinus opm{};
constexpr static Zero zero{};
int n;
vector<T> tree; // tree[i] -> sum of [i-lowbit(i)+1,i]
BIT(int n_=0):n(n_),tree(n+1,zero()) {}
void init(int n_) { n=n_; tree.assign(n+1,zero()); }
void init(const vector<T> &vec) { // v[0 ~ n_-1]
n=vec.size();
vector<T> tmp(n+1,zero());
for(int i=1;i<=n;i++) tmp[i]=opp(tmp[i-1],vec[i-1]);
for(int i=1;i<=n;i++) tree[i]=opm(tmp[i],tmp[i-lowbit(i)]);
}
void add(int p,T v) {
for(;p<=n;p+=lowbit(p)) tree[p]=opp(tree[p],v);
}
T sum(int p) {
T ans=zero();
for(;p;p-=lowbit(p)) ans=opp(ans,tree[p]);
return ans;
}
T sum(int l,int r) { return opm(sum(r),sum(l-1)); }
};
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
struct SegT {
int n;
vector<int> mx;
SegT(int n) : n(n), mx(n << 2) {}
int P, V;
void update(int o, int l, int r) {
if (l == r) {
mx[o] = V;
return;
}
int mid = (l + r) >> 1;
if (P <= mid) update(o << 1, l, mid);
else update(o << 1 | 1, mid + 1, r);
mx[o] = max(mx[o << 1], mx[o << 1 | 1]);
}
int L, R;
int query(int o, int l, int r) {
if (L <= l && r <= R) {
return mx[o];
}
int mid = (l + r) >> 1;
int ans = 0;
if (L <= mid) cmax(ans, query(o << 1, l, mid));
if (mid < R) cmax(ans, query(o << 1 | 1, mid + 1, r));
return ans;
}
int bstl() {
int l = 0, r = R - 1;
int ans = R, mid;
while (l <= r) {
mid = (l + r) >> 1;
L = mid;
if (query(1, 0, n - 1) <= V) {
ans = mid;
r = mid - 1;
} else {
l = mid + 1;
}
}
return ans;
}
int bstr() {
int l = L + 1, r = n - 1;
int ans = L, mid;
while (l <= r) {
mid = (l + r) >> 1;
R = mid;
if (query(1, 0, n - 1) <= V) {
ans = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
return ans;
}
};
void solve() {
int n; cin >> n;
vector<int> p(n), c(n);
for (int& i : p) cin >> i;
for (int& i : c) cin >> i;
vector<int> pos(n + 1);
for (int i = 0; i < n; i++) {
pos[p[i]] = i;
}
vector<vector<int>> col(n + 1);
vector<vector<int>> colp(n + 1);
for (int i = 0; i < n; i++) {
col[c[i]].push_back(i);
colp[c[i]].push_back(p[i]);
}
SegT segt(n);
for (int i = 0; i < n; i++) {
segt.P = i;
segt.V = p[i];
segt.update(1, 0, n - 1);
}
vector<int> l(n), r(n);
for (int ii = 1; ii <= n; ii++) {
if (col[ii].empty()) continue;
for (int i : col[ii]) {
segt.P = i;
segt.V = 0;
segt.update(1, 0, n - 1);
}
for (int i : col[ii]) {
segt.R = i;
segt.V = p[i];
l[i] = segt.bstl();
segt.L = i;
r[i] = segt.bstr();
}
for (int i : col[ii]) {
segt.P = i;
segt.V = p[i];
segt.update(1, 0, n - 1);
}
}
// for (int i = 0; i < n; i++) {
// cerr << i << ' ' << l[i] << ' ' << r[i] << '\n';
// }
Mint ans = 1;
BIT<int> bit(n);
for (int ii = 1; ii <= n; ii++) {
if (colp[ii].empty()) continue;
sort(colp[ii].rbegin(), colp[ii].rend());
for (int i : colp[ii]) {
i = pos[i];
bit.add(i + 1, 1);
ans *= bit.sum(l[i] + 1, r[i] + 1);
}
for (int i : colp[ii]) {
i = pos[i];
bit.add(i + 1, -1);
}
}
cout << ans << '\n';
}
int main() {
int t = 1;
cin >> t;
while (t--) {
solve();
}
return 0;
}
数据生成
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
inline int rndi() { return rnd(); }
inline int rndi(int n) { return rnd()%n; }
inline int rndi(int l, int r) { return rnd() % (r - l + 1) + l; }
inline LL rndll() { return rnd_64(); }
inline LL rndll(LL n) { return rnd_64() % n; }
inline LL rndll(LL l, LL r) { return rnd_64() % (r - l + 1) + l; }
inline char rndc() { return rnd() % 26 + 'a'; }
inline char rndC() { return rnd() % 26 + 'A'; }
inline char rnddig() { return rnd() % 10 + '0'; }
inline char rndcha() {
int v = rnd() % 52;
return v < 26 ? v + 'a' : (v - 26 + 'A');
}
template<typename T>
void shuffle(vector<T>& vec) { shuffle(vec.begin(), vec.end(), rnd); }
vector<int> rnd_permu(int n, int from = 0) {
vector<int> vec(n);
iota(vec.begin(), vec.end(), from);
shuffle(vec);
return vec;
}
void solve() {
int n = 5; cout << n << '\n';
auto p = rnd_permu(n, 1);
for (int i : p) cout << i << ' ';
cout << '\n';
for (int i = 0; i < n; i++) {
cout << rndi(1, n) << ' ';
}
cout << '\n';
}
int main() {
int t = 1; cout << t << '\n';
while (t--) {
solve();
}
return 0;
}
暴力
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
void solve() {
int n; cin >> n;
vector<int> p(n), c(n);
for (int& i : p) cin >> i;
for (int& i : c) cin >> i;
auto q = p;
sort(q.begin(), q.end());
auto getcol = [&](const vector<int>& a) -> vector<vector<int>> {
vector<vector<int>> ans;
for (int i = 1; i <= n; i++) {
vector<int> cur;
for (int j = 0; j < n; j++) {
if (a[j] >= i) {
cur.push_back(c[j]);
}
}
ans.push_back(cur);
}
return ans;
};
auto colp = getcol(p);
int ans = 0;
do {
if (getcol(q) == colp) ans++;
} while (next_permutation(q.begin(), q.end()));
cout << ans << '\n';
}
int main() {
int t = 1;
cin >> t;
while (t--) {
solve();
}
return 0;
}
一组有启发性的小数据
1
5
5 3 1 4 2
5 2 3 5 5