Codeforces 1336F Journey

一句话题意:给你一棵n个点的树,和树上的m条路径。问有多少对路径的交长度大于k?

n m 150000, 4s

我实现的做法是题解做法,代码比较长。但是我还是比网上的一个同学短了100行。。

大概合并的线段树还是可以使用pbds里面的平衡树来代替,但是我写代码之前并不知道……

写完之后去读了最短的代码(代码长度0.5倍),发现是中国选手。没读懂,问出题人,出题人也读不懂。。

找到他的QQ就问了一下他,然后发现了一个很妙的做法,然后发现好多选手都是写的这个做法,只有我太憨了去实现题解。。。下面描述一下那个做法。

首先进行一步转化。

一个树上的路径,[|路径|>=k] = #长度为k的子段-#长度为k+1的子段

因此只需要对树上的所有长度为k或者k+1的路径求一下被覆盖了多少次,就得到答案。比如被覆盖了c次,那么贡献的绝对值就是c(c-1)/2,符号取决于长度是k(正)还是k+1(负)。

首先进行重链剖分,然后对于一条路径,长度为k的段大概是这样的:

竖直的和拐弯的段

竖直的覆盖的段上的贡献可以直接使用剖分上的差分来做。现在问题就出在拐弯的段上。

拐弯的段的处理很巧妙,考虑我们把整个区间拆分成重链上。

一个1到10的链

然后考虑任何一个长度为k的跨越lca(4号点)的子段,我们采用他开始节点所在重链、终止节点所在重链、开始节点的深度来差分记录,开一个数组套map套map就行了,复杂度也是两个log……太妙了。

下面是我的题解做法代码

#pragma comment(linker, "/stack:200000000")
#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
#pragma GCC optimize("unroll-loops")
#include <bits/stdc++.h>


using namespace std;
#define set0(x) memset(x,0,sizeof(x))
#define F first
#define S second
#define PB push_back
#define MP make_pair
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define trav(a, x) for(auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()

typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef vector<int> VI;
template<typename T> void read(T &x){
  x = 0;char ch = getchar();ll f = 1;
  while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
  while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}
template<typename T, typename... Args> void read(T &first, Args& ... args) {
  read(first);
  read(args...);
}


const int N = 150050;
int n,m,k,s[N],t[N],lc[N],u,v;
vector<int> G[N],G2[N],T[N],oT[N];

int fa[N][19],d[N],dfn[N],sz[N],tim = 0;
void dfs0(int num,int cf = 0){
  dfn[num] = ++tim;
  fa[num][0] = cf;
  sz[num] = 1;
  for(int j=0;fa[num][j];j++)fa[num][j+1] = fa[fa[num][j]][j];
  for(auto ct:G[num]){
    if(ct==cf)continue;
    d[ct] = d[num]+1;
    dfs0(ct,num);
    sz[num]+=sz[ct];
  }
}

inline int lca(int u,int v){
  if(d[u]<d[v])swap(u,v);
  for(int i=18;i>=0;i--)if(fa[u][i] && d[fa[u][i]]>=d[v])u = fa[u][i];
  for(int i=18;i>=0;i--)if(fa[u][i]!=fa[v][i])u = fa[u][i],v = fa[v][i];
  return u == v?u:fa[u][0];
}

inline int anc(int u,int x){
  for(int i=18;i>=0;i--) if(fa[u][i] && d[fa[u][i]]>=x)u = fa[u][i];
  return u;
}

struct Fenwick{
  int s[N];
  vector<pii> log;
  inline void update(int pos, int dif) {
    log.emplace_back(pos,dif);
    while(pos<N){
      s[pos]+=dif;
      pos+=pos&(-pos);
    }
  }
  void reset(){
    int pos,dif;
    for(auto ct:log){
      pos = ct.F;
      dif = -ct.S;
      while(pos<N){
        s[pos]+=dif;
        pos+=pos&(-pos);
      }
    }
    log.clear();
  }
  inline int query(int pos) {
    int res = 0;
    while(pos){
      res+=s[pos];
      pos-=pos&(-pos);
    }
    return res;
  }
}A;

vector<pii> F[N],F2[N];

struct node{
  int ls = 0,rs = 0;
  int val = 0;
}nds[10000010];
int cnt = 0,rts[N];
#define mid ((cl+cr)/2)
int query(int id,int l,int r,int cl = 0,int cr = n){
  if(id == 0 || (l<=cl && cr<=r))return nds[id].val;
  return ((l<=mid)?query(nds[id].ls,l,r,cl,mid):0)+((r>mid)?query(nds[id].rs,l,r,mid+1,cr):0);
}
void add(int &id,int x,int cl = 0,int cr = n){
  if(!id)id = ++cnt;nds[id].val++;if(cl == cr)return;
  if(x<=mid)
    add(nds[id].ls,x,cl,mid);
  else
    add(nds[id].rs,x,mid+1,cr);
}
int merge(int a,int b){
  if(!a || !b)return a|b;
  nds[a].val+=nds[b].val;
  nds[a].ls = merge(nds[a].ls,nds[b].ls);
  nds[a].rs = merge(nds[a].rs,nds[b].rs);
  return a;
}

void reset(){
  memset(nds,0,sizeof(nds[0])*(cnt+1));
  cnt = 0;
}
ll ans = 0;

int crt = 0;
void dfs1(int num){
  int tgt;
  auto add_it = [&](int ct){
    if(d[num]+d[ct]-2*d[crt]>=k){
      if(d[num]>=d[crt]+k){
        tgt = anc(num,d[num]-k+1);
        ans+=nds[rts[num]].val;
        ans-=query(rts[num],dfn[tgt],dfn[tgt]+sz[tgt]-1);
      }else{
        tgt = anc(ct,d[crt]+k-(d[num]-d[crt]));
        ans +=query(rts[num],dfn[tgt],dfn[tgt]+sz[tgt]-1);
      }
    }
  };
  int mx = -1,ms = -1;
  for(auto ech:G2[num]){
    dfs1(ech);
    if((int)T[ech].size()>mx){
      mx = T[ech].size();
      ms = ech;
    }
  }
  if(ms!=-1)rts[num] = merge(rts[num],rts[ms]);
  for(auto ct:T[num]){
    add_it(ct);
    add(rts[num],dfn[ct]);
  }
  if(ms!=-1){
    //cout<<ms<<endl;
    swap(T[num],T[ms]);
    for(auto ct:T[ms])T[num].PB(ct);
  }
  for(auto ech:G2[num]){
    if(ech == ms){
      //cout<<"JMP"<<endl;
      continue;
    }
    for(auto ct:T[ech]){
      add_it(ct);
      T[num].PB(ct);
    }
    rts[num] = merge(rts[num],rts[ech]);
  }
}

int main() {
  read(n,m,k);
  for(int i=1;i<n;i++){
    read(u,v);
    G[u].PB(v);
    G[v].PB(u);
  }
  d[1] = 1;
  dfs0(1);
  for(int i=0;i<m;i++){
    read(s[i],t[i]);
    if(dfn[s[i]]>dfn[t[i]])swap(s[i],t[i]);
    lc[i] = lca(s[i],t[i]);
    F[d[lc[i]]].emplace_back(s[i],t[i]);
    F2[lc[i]].emplace_back(s[i],t[i]);
  }
  for(int i=n;i>=1;i--){
    for(auto ct:F[i]){
      int l = ct.F,r = ct.S;
      ans+=A.query(dfn[l]);
      ans+=A.query(dfn[r]);
    }
    for(auto ct:F[i]){
      int l = ct.F,r = ct.S;
      if(d[l]>=i+k){
        l = anc(l,i+k);
        A.update(dfn[l],1);
        A.update(dfn[l]+sz[l],-1);
      }
      if(d[r]>=i+k){
        r = anc(r,i+k);
        A.update(dfn[r],1);
        A.update(dfn[r]+sz[r],-1);
      }
    }
  }
  A.reset();
  auto cmp = [&](int a,int b)->bool{return dfn[a]<dfn[b];};
  for(int i=1;i<=n;i++){
    crt = i;
    vector<int> cu,scope;
    scope.PB(i);
    for(auto ct:F2[i]){
      cu.PB(ct.F);
      T[ct.F].PB(ct.S);
      oT[ct.F].PB(ct.S);
    }
    sort(all(cu),cmp);
    cu.erase(unique(all(cu)),cu.end());
    vector<int> stk;
    stk.PB(i);
    rts[i] = ++cnt;
    auto pb = [&](int a){
      rts[a] = ++cnt;
      scope.PB(a);
      stk.PB(a);
    };
    for(auto ct:cu){
      int dd = lca(ct,stk.back());
      while(d[stk.back()]>d[dd]){
        if(d[stk[stk.size()-2]]>d[dd]) G2[stk[stk.size()-2]].PB(stk.back());
        else G2[dd].PB(stk.back());
        stk.pop_back();
      }
      if(d[stk.back()]<d[dd]) pb(dd);
      if(d[stk.back()]<d[ct]) pb(ct);
    }
    while(stk.size()>=2){
      G2[stk[stk.size()-2]].PB(stk.back());
      stk.pop_back();
    }
    dfs1(i);
    for(auto ct:cu){
      ans+=1ll*A.query(dfn[ct])*oT[ct].size();
      for(auto ed:oT[ct]){
        if(d[ed]>=d[i]+k){
          int cc = anc(ed,d[i]+k);
          A.update(dfn[cc],1);
          A.update(dfn[cc]+sz[cc],-1);
        }
      }
    }
    A.reset();
    reset();
    for(auto ct:scope){
      G2[ct].clear();
      rts[ct] = 0;
      T[ct].clear();
      oT[ct].clear();
    }
  }
  cout<<ans<<endl;
  return 0;
}

下面是高水平选手的nb做法代码

#include<bits/stdc++.h>
 
using namespace std;
 
typedef long long LL;
#define N 300000
 
int n,m,k,a[N][2],fa[N],sz[N],top[N],d[N],prf[N],f1[N],q1[N],q2[N];
vector<int> g[N],b[N];
map<int,map<int,int> > f2[N];
LL ans;
 
void dfs1(int u){
	d[u]=d[fa[u]]+1;
	sz[u]=1;
	for (int v:g[u])
		if (v!=fa[u]){
			fa[v]=u;
			dfs1(v);
			sz[u]+=sz[v];
		}
}
 
void dfs2(int u){
	if (!top[u]) top[u]=u;
	b[top[u]].push_back(u);
	int t=0;
	for (int v:g[u])
		if (v!=fa[u]&&sz[v]>sz[t]) t=v;
	if (!t) return;
	prf[u]=t; top[t]=top[u]; dfs2(t);
	for (int v:g[u])
		if (v!=fa[u]&&v!=t) dfs2(v);
}
 
int lca(int x,int y){
	for (;top[x]!=top[y];x=fa[top[x]])
		if (d[top[x]]<d[top[y]]) swap(x,y);
	return d[x]<d[y]?x:y;
}
 
int go(int x,int k){
	int len=d[x]-d[top[x]];
	if (len<k) return go(fa[top[x]],k-len-1);
	return b[top[x]][len-k];
}
 
int add1(int x,int y){
	int len=d[x]-d[y];
	if (len<k) return x;
	int z=go(x,len-k+1);
	++f1[x]; --f1[z];
	return z;
}
 
void add2(int u1,int u2,int v1,int v2){
	if (top[u1]>top[v1]){
		swap(u1,v1); swap(u2,v2);
	}
	if (d[u1]>d[u2]) swap(u1,u2);
	++f2[top[u1]][top[v1]][d[u1]];
	--f2[top[u1]][top[v1]][d[u2]+1];
}
 
void add(int x,int y){
	int z=lca(x,y);
	x=add1(x,z); y=add1(y,z);
	int u=x,v=y;
	int n1=0,n2=0;
	for (;top[u]!=top[z];u=fa[top[u]]) q1[++n1]=u;
	for (;top[v]!=top[z];v=fa[top[v]]) q2[++n2]=v;
	if (u!=z) q1[++n1]=u; if (v!=z) q2[++n2]=v;
	for (int i1=1,i2=n2;i1<=n1;++i1){
		int u1=q1[i1],u2=top[u1];
		if (d[u2]<=d[z]) u2=prf[z];
		for (;i2;--i2){
			int v1=q2[i2],v2=top[v1];
			if (d[v2]<=d[z]) v2=prf[z];
			if (d[u1]+d[v1]-d[z]*2<k) continue;
			int len=d[u1]+d[v1]-d[z]*2;
			int w1=go(v1,len-k),w2=0;
			len=d[u2]+d[v1]-d[z]*2;
			if (len>=k){
				w2=go(v1,len-k);
				add2(u1,u2,w1,w2);
				break;
			}
			len=d[u1]+d[v1]-d[z]*2;
			w2=go(u1,len-k);
			add2(u1,w2,w1,v1);
			u1=fa[w2];
		}
	}
}
 
LL C(LL x){return x*(x-1)/2;}
 
void dfs3(int u,int t){
	for (int v:g[u])
		if (v!=fa[u]){
			dfs3(v,t);
			f1[u]+=f1[v];
		}
	ans+=C(f1[u])*t;
}
 
void calc(int t){
	for (int i=1;i<=n;++i){
		for (auto j:f2[i]){
			LL sum=0,lst=0;
			for (auto k:j.second){
				ans+=(k.first-lst)*C(sum)*t;
				lst=k.first;
				sum+=k.second;
			}
		}
		f2[i].clear();
	}
}
 
int main(){
	scanf("%d%d%d",&n,&m,&k);
	for (int i=1;i<n;++i){
		int x,y; scanf("%d%d",&x,&y);
		g[x].push_back(y); g[y].push_back(x);
	}
	dfs1(1); dfs2(1);
	for (int i=1;i<=m;++i){
		scanf("%d%d",a[i]+0,a[i]+1);
		add(a[i][0],a[i][1]);
	}
	dfs3(1,1); calc(1);
	memset(f1,0,sizeof f1);
	++k;
	for (int i=1;i<=m;++i)
		add(a[i][0],a[i][1]);
	dfs3(1,-1); calc(-1);
	printf("%lld\n",ans);
	
	return 0;
}

《Codeforces 1336F Journey》上有1条评论

发表评论

电子邮件地址不会被公开。 必填项已用*标注