予鈴

競プロとか備忘録とか…

CS Academy Array Removal

csacademy.com

問題概要

N要素の数列が与えられ、N要素のsubarrayが与えられる。
N 回連続する数列の和の最大値を出力する。ただし、subarrayの i番目の数字は使えなくなる。

解く前に考えてたこと

  • DP : 不可能そう。使えない数字が出るたびに引くと間に合わない
  • StarrySkyTree:区間がどれほど重複してるか数える必要が出てきて、無理となった。

解法

union-findとsubarrayを逆からみていくと良い。
i番目の数字を使えるとすると、i-1i+1番目の要素とuniteしていく。各集合のparentをkeyとする和の配列sumを持っていると良い。
答えとなりうるのは、i番目の要素を足す前の最大値か、足した後の和となるので、そのあたりはうまいことやる。

#include <bits/stdc++.h>
using ll = long long;
#define int ll
#define INF 1e9
#define EPS 0.0000000001
#define rep(i,n) for(int i=0;i<n;i++)
#define all(in) in.begin(), in.end()
template<class T, class S> void cmin(T &a, const S &b) { if (a > b)a = b; }
template<class T, class S> void cmax(T &a, const S &b) { if (a < b)a = b; }
using namespace std;
#define MAX 9999999
class unionfind {
    vector<int> par, rank, size_ ,sum;
public:
    unionfind(int n) :par(n), rank(n), size_(n, 1),sum(n,0) {
        iota(all(par), 0);
    }
    int find(int x) {
        if (par[x] == x)return x;
        return par[x] = find(par[x]);
    }
    void unite(int x, int y) {
        
        x = find(x), y = find(y);
        if (x == y)return;
        if (rank[x] < rank[y])swap(x, y);
        par[y] = x;
        size_[x] += size_[y];
        if (rank[x] == rank[y])rank[x]++;
    }
    bool same(int x, int y) {
        return find(x) == find(y);
    }
    int size(int x) {
        return size_[find(x)];
    }
};
signed main(){
    int n, maxi = 0;
    cin >> n;
    vector<int> v(n,0),sub(n),ans,sum(2*n,0);
    vector<bool> used(n,false);
    unionfind uni(2*n);
    rep(i,n) cin >> v[i];
    rep(i,n) cin >> sub[i];
    rep(i,n) {
        uni.unite(i,i+n);
        sum[uni.find(i)] = v[i];
    }
    reverse(all(sub));
    rep(i,sub.size()){
        int index = sub[i]-1;
        maxi = max(maxi,v[index]);
        if(index+1 < n){
            if(used[index+1]){
                int val = sum[uni.find(index)]+sum[uni.find(index+1)];
                uni.unite(index, index+1);
                sum[uni.find(index)] =val;
            }
        }
        if(index-1 >=0){
            if(used[index-1]){
                int val = sum[uni.find(index)]+sum[uni.find(index-1)];
                uni.unite(index,index-1);
                sum[uni.find(index)] = val;
                
            }
        }
        cmax(maxi, sum[uni.find(index)]);
        
        ans.push_back(maxi);
        used[index] = true;
    }
    reverse(all(ans));
    rep(i,ans.size()) cout << ans[i] << endl;
}