1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
| mt19937 rnd(20130825); template<int N> class treap { private: struct node { int l, r, sz, v, w, pr; } tr[N]; int cnt, root; void update(int rt) { tr[rt].sz = tr[tr[rt].l].sz + tr[tr[rt].r].sz + tr[rt].w; } void lturn(int &rt) { int x = tr[rt].r; tr[rt].r = tr[x].l; tr[x].l = rt; tr[x].sz = tr[rt].sz; update(rt); rt = x; } void rturn(int &rt) { int x = tr[rt].l; tr[rt].l = tr[x].r; tr[x].r = rt; tr[x].sz = tr[rt].sz; update(rt); rt = x; } bool _insert(int &rt, int v) { if (!rt) return tr[rt = ++ cnt] = {0, 0, 1, v, 1, (int)rnd()}, 1; tr[rt].sz ++; if (tr[rt].v == v) return tr[rt].w ++, 0; else if (v < tr[rt].v) { bool res = _insert(tr[rt].l, v); if (tr[tr[rt].l].pr < tr[rt].pr) rturn(rt); return res; } else { bool res = _insert(tr[rt].r, v); if (tr[tr[rt].r].pr < tr[rt].pr) lturn(rt); return res; } } bool _erase(int &rt, int v) { if (!rt) return 0; if (tr[rt].v == v) { if (tr[rt].w > 1) return tr[rt].sz --, tr[rt].w --, 1; if (!(tr[rt].l && tr[rt].r)) rt = tr[rt].l | tr[rt].r; else if (tr[tr[rt].l].pr < tr[tr[rt].r].pr) return rturn(rt), _erase(rt, v); else return lturn(rt), _erase(rt, v); } else { if (v < tr[rt].v) {if (_erase(tr[rt].l, v)) tr[rt].sz --;} else {if (_erase(tr[rt].r, v)) tr[rt].sz --;} } return 1; } int _prev(int rt, int v) { if (!rt) return -INF; if (tr[rt].v < v) return max(tr[rt].v, _prev(tr[rt].r, v)); else return _prev(tr[rt].l, v); } int _next(int rt, int v) { if (!rt) return INF; if (tr[rt].v > v) return min(tr[rt].v, _next(tr[rt].l, v)); else return _next(tr[rt].r, v); } int _rank(int rt, int v) { if (!rt) return 1; if (tr[rt].v == v) return tr[tr[rt].l].sz + 1; else if (v < tr[rt].v) return _rank(tr[rt].l, v); else return tr[tr[rt].l].sz + tr[rt].w + _rank(tr[rt].r, v); } int _kth(int rt, int k) { if (!rt) return -INF; if (k <= tr[tr[rt].l].sz) return _kth(tr[rt].l, k); else if (k > tr[tr[rt].l].sz + tr[rt].w) return _kth(tr[rt].r, k - tr[tr[rt].l].sz - tr[rt].w); else return tr[rt].v; } public: int size() {return tr[root].sz;} bool insert(int v) {return _insert(root, v);} bool erase(int v) {return _erase(root, v);} int prev(int v) {return _prev(root, v);} int next(int v) {return _next(root, v);} int rank(int v) {return _rank(root, v);} int kth(int k) {return _kth(root, k);} };
|