https://codeforces.com/contest/1997/problem/E
题目大意
给定 $n\ (\le 2 \cdot 10^5)$ 长的数组和 $q\ (\le 2 \cdot 10^5)$ 个询问。数组中元素 $a_i\ (\le 2 \cdot 10^5)$。对于参数 $k$,我们依次在数组中标记 $k$ 个不小于 $1$ 的数,之后从最后标记的下标开始,再标记 $k$ 个不小于 $2$ 的数,以此类推。
每个询问给出 $i\ (1 \le i \le n)$ 和 $x\ (1 \le x \le n)$,问 $i$ 位置在参数 $k = x$ 时是否被标记。
简要题解
先考虑如何判断一个位置是否被标记。我们考虑每次标记的 $k$ 个的一段,假设这段是第 $p$ 段,则这一段中 $a_i \ge p$ 的 $i$ 会被标记,而其他的 $i$ 不会。
因为 $x$ 不超过 $n$,我们考虑枚举 $x$。(实际上超过了也没关系。$x \ge n$ 时所有位置都会被标记。)
注意到,每一段至少都有 $x$ 长,也就是说最多会分 $\lfloor \frac{n}{x} \rfloor$ 段,而对于所有的 $x$ 最多有 $O(n \log n)$ 段。因此如果我们需要一个办法能够快速的处理段。
考虑我们已经有了参数为 $x$ 第 $p$ 段的末尾位置 $pos$(最后一个标记元素的下标),此时相当于我们要找 $[pos + 1, n]$ 中最小的位置 $r$ 使得 $\sum_{pos + 1 \le i \le r}[a_i \ge (p + 1)] == x$。这个位置显然是可以二分的,那么现在问题就是有什么数据结构可以快速的完成这个求和操作。线段树可以维护所有 $\ge p + 1$ 的位置。
每次固定 $x$ 让 $p$ 变化那么线段树中的合法数字是不容易维护的(当然也可以可持久化)。 更简单的做法是固定 $p$,这样每次 $p$ 变化再更新线段树即可。可以用一个集合维护所有还没有分完的 $x$。 对于这些 $x$ 处理出段,然后处理段中的询问。
最后注意到线段树上可以直接做二分。至此我们得到了一个 $O(n \log^2 n)$ 的优秀做法。
特别注意,这个题是多组数据,而每组数据 $a_i$ 的范围和 $n$ 无关。(但实际上 $> n$ 的数据是没有意义的,可以完全转化到 $= n$)
复杂度
$T$:$O(n \log^2 n)$
$S$:$O(n)$
代码实现
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
struct SegT {
int n;
vector<int> sum;
SegT(int n) : n(n), sum(n << 2) {}
int P, V;
void update(int o, int l, int r) {
sum[o] += V;
if (l == r) return;
int mid = (l + r) >> 1;
if (P <= mid) update(o << 1, l, mid);
else update(o << 1 | 1, mid + 1, r);
}
int L, S;
pair<int, int> query(int o, int l, int r) {
if (r < L) return {-1, 0};
if (L <= l && sum[o] < S) return {-1, sum[o]};
if (l == r) {
return {l, 1};
}
int mid = (l + r) >> 1;
PII ans = query(o << 1, l, mid);
if (ans.first != -1) return ans;
S -= ans.second;
return query(o << 1 | 1, mid + 1, r);
}
};
void solve() {
int n, q; cin >> n >> q;
vector<int> a(n);
for (int& i : a) {
cin >> i;
if (i > n) i = n;
}
// map<int, vector<int>> pos;
vector<vector<int>> pos(n + 1);
for (int i = 0; i < n; i++) {
pos[a[i]].push_back(i);
}
vector<vector<PII>> qs(n + 1);
int j, x;
for (int i = 0; i < q; i++) {
cin >> j >> x;
qs[x].push_back({j - 1, i});
}
for (int i = 1; i <= n; i++) {
sort(qs[i].begin(), qs[i].end());
}
vector<int> nxt(n + 1);
vector<int> valid, tmp;
vector<int> qi(n + 1, 0);
for (int i = 1; i <= n; i++) valid.push_back(i);
SegT segt(n);
for (int i = 0; i < n; i++) {
segt.P = i;
segt.V = 1;
segt.update(1, 0, n - 1);
}
vector<int> ans(q);
for (int p = 1; p <= n; p++) {
for (int i : pos[p - 1]) {
segt.P = i;
segt.V = -1;
segt.update(1, 0, n - 1);
}
tmp.clear();
for (int x : valid) {
segt.L = nxt[x];
segt.S = x;
auto [r, _] = segt.query(1, 0, n - 1);
// cerr << x << ' ' << p << ' ' << nxt[x] << ' ' << r << endl;
if (r == -1) r = n - 1;
while (qi[x] < (int)qs[x].size() && qs[x][qi[x]].first <= r) {
if (a[qs[x][qi[x]].first] >= p) ans[qs[x][qi[x]].second] = true;
// cerr << " " << qs[x][qi[x]].first << ' ' << qs[x][qi[x]].second << ' ' << (ans[qs[x][qi[x]].second]) << endl;
qi[x]++;
}
if (r + 1 < n) {
nxt[x] = r + 1;
tmp.push_back(x);
}
}
valid = tmp;
}
for (int i = 0; i < q; i++) {
cout << (ans[i] ? "YES" : "NO") << '\n';
}
}
int main() {
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}
Next: D. Med-imize - Codeforces Round 963 (Div. 2)