#include <bits/stdc++.h>
#define int long long

using namespace std;

// ---- Limits -----------------------------------------------------------------
const int N = 300000;               // array capacity; vertices are numbered 1..n
const int inf = (int)0x3f3f3f3f;    // NOTE(review): not referenced anywhere in this file

// ---- Per-test scalars -------------------------------------------------------
int T;      // number of test cases
int n;      // vertex count of the current tree
int ans;    // NOTE(review): not part of the value main prints
int sum;    // total printed by main for each test case
int rt;     // NOTE(review): unused
int cnt;    // NOTE(review): unused

// ---- Tree state -------------------------------------------------------------
int sz[N];          // subtree sizes (sz[x] is temporarily rewritten inside dp)
int son[N][2];      // [0] heaviest child, [1] second-heaviest child (dp rewrites [0] temporarily)
int f[N][20];       // f[x][k]: vertex reached after 2^k steps along son[.][0]
vector<int> v[N];   // adjacency lists

// Forward declarations; the definitions follow main.
void update(int);
void dfs(int, int);
void get(int);
void dp(int, int);

// Reads T test cases. Each case is a tree on n vertices given as n-1 edges.
// For each case, runs the rooting pass (dfs) and the re-rooting pass (dp),
// then prints the total accumulated in `sum`.
signed main() {
	// Untie C++ streams from C stdio. Each test case reads about 2n integers,
	// and stdio-synchronised cin can be slow enough at that volume to risk
	// exceeding the time limit.
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	cin >> T;
	while (T--) {
		// Reset what the previous test case used. `n` still holds its size here.
		for (int i = 1; i <= n; i++) {
			v[i].clear();
			son[i][0] = son[i][1] = 0;
		}
		sum = 0;
		cin >> n;
		for (int i = 1; i < n; i++) {
			int x, y;
			cin >> x >> y;
			v[x].push_back(y);
			v[y].push_back(x);
		}
		dfs(1, 0);
		dp(1, 0);
		cout << sum << "\n";
	}
	return 0;
}
// Rebuild x's binary-lifting row along heavy children.
// Level 0 is son[x][0]; level k is two level-(k-1) hops: f[x][k] = f[f[x][k-1]][k-1].
// Callers re-run this after changing son[x][0].
void update(int x) {
	int *row = f[x];
	row[0] = son[x][0];
	for (int k = 1; k < 20; ++k)
		row[k] = f[row[k - 1]][k - 1];
}
// First pass, with main calling dfs(1, 0).
// Computes subtree sizes and keeps the two largest children of x in
// son[x][0] and son[x][1], then builds x's heavy-chain jump row.
void dfs(int x, int frt) {
	sz[x] = 1;
	for (int u : v[x]) {
		if (u == frt) continue;
		dfs(u, x);
		sz[x] += sz[u];
		int &first = son[x][0];
		int &second = son[x][1];
		if (sz[u] > sz[first]) {
			// u becomes the heaviest; the old heaviest drops to second place.
			second = first;
			first = u;
		} else if (sz[u] > sz[second]) {
			second = u;
		}
	}
	update(x);
}
// Centroid query for the component "rooted" at x.
// sz[x] is the component size, and the search walks down x's heavy chain
// (f[][] / son[][0]). Adds every centroid of that component to `sum`:
//   - the deepest chain vertex whose size still exceeds half the component;
//   - that vertex's heavy child as well, when removing the child leaves at
//     most half behind (i.e. the child holds exactly half the vertices).
void get(int x) {
	const int all = sz[x];
	const int half = all >> 1;
	// Binary-lift to the deepest heavy-chain vertex with size > half.
	for (int i = 19; i >= 0; i--) {
		if (sz[f[x][i]] > half) {
			x = f[x][i];
		}
	}
	sum += x;
	// The second centroid must go into `sum`, the total that main prints.
	// Accumulating it into a separate, never-printed counter drops it from
	// the answer.
	if (all - sz[son[x][0]] <= half) sum += son[x][0];
}
// Re-rooting pass. For each neighbour u of x other than fa, cut edge (x, u)
// and add the centroids of both halves via get():
//   - x's half: x temporarily gets size n - sz[u], and son[x][0] becomes its
//     largest remaining neighbour (the heaviest other child, or fa);
//   - u's half: u's subtree as computed by dfs.
// It then recurses into u with x acting as u's parent. x's son[x][0], sz[x]
// and jump row are restored before returning.
void dp(int x, int fa) {
	const int keepHeavy = son[x][0];   // heavy child before re-rooting
	const int keepSize = sz[x];        // subtree size before re-rooting
	const int above = n - keepSize;    // vertices on fa's side of x
	for (int u : v[x]) {
		if (u == fa) continue;
		// Heaviest child of x other than u.
		int rival;
		if (u == keepHeavy) rival = son[x][1];
		else rival = keepHeavy;
		// Heavy neighbour once (x, u) is cut. On a tie, fa wins.
		if (sz[rival] > above) son[x][0] = rival;
		else son[x][0] = fa;
		update(x);
		sz[x] = n - sz[u];
		get(x);
		get(u);
		dp(u, x);
	}
	// Undo the temporary re-rooting of x for the caller.
	son[x][0] = keepHeavy;
	sz[x] = keepSize;
	update(x);
}



1 条评论

  • @ 2025-7-17 17:03:41
    #include<bits/stdc++.h>
    using namespace std; 
    long long ans;                      // answer accumulated for the current test case
    const int N = (int)300000;          // array capacity; vertices are numbered 1..n
    const int inf = (int)0x3f3f3f;      // NOTE(review): unused
    int t;                              // number of test cases
    int n;                              // vertex count of the current tree
    int son[N][2];                      // [0] heaviest child, [1] second-heaviest child
    int siz[N];                         // subtree sizes (rewritten temporarily in dp)
    int p[N][20];                       // p[u][k]: 2^k steps down along son[.][0]
    int idx;                            // NOTE(review): reset in main but never read
    vector<int> g[N];                   // adjacency lists
    // Recompute u's heavy-chain jump row, starting from son[u][0] at level 0.
    // Each higher level is two hops of the level below it.
    void update(int u) {
    	int *jump = p[u];
    	jump[0] = son[u][0];
    	for (int k = 1; k <= 19; ++k) {
    		jump[k] = p[jump[k - 1]][k - 1];
    	}
    }
    
    // Rooting pass, with main calling dfs(1, 0).
    // Fills siz[], keeps the two largest children in son[u][0] and son[u][1],
    // then builds u's jump row.
    void dfs(int u, int fa) {
    	siz[u] = 1;
    	for (int w : g[u]) {
    		if (w == fa) continue;
    		dfs(w, u);
    		siz[u] += siz[w];
    		int &big = son[u][0];
    		int &second = son[u][1];
    		if (siz[w] > siz[big]) {
    			second = big;
    			big = w;
    		} else if (siz[w] > siz[second]) {
    			second = w;
    		}
    	}
    	update(u);
    }
    
    // Add the centroid(s) of the component rooted at u, of size siz[u], to ans.
    // First, walk down u's heavy chain to the deepest vertex whose size still
    // exceeds half. Then also add that vertex's heavy child when removing the
    // child leaves at most half of the component behind.
    void get(int u) {
    	const int total = siz[u];
    	const int half = total / 2;
    	int c = u;
    	for (int k = 19; k >= 0; --k) {
    		const int nxt = p[c][k];
    		if (siz[nxt] > half) c = nxt;
    	}
    	ans += c;
    	const int h = son[c][0];
    	if (total - siz[h] <= half) ans += h;
    }
    void dp(int u, int fa) {
    	int x = son[u][0], y = siz[u];
    	for (int i = 0; i < g[u].size(); i++) {
    		int v = g[u][i];
    		if(v == fa) {
    			continue;
    		}
    		if(v == x) {
    			son[u][0] = siz[son[u][1]] > n - y ? son[u][1] : fa;
    		} else {
    			son[u][0] = siz[x] > n - y ? x : fa;
    		}
    		update(u);
    		siz[u] = n - siz[v];
    		get(u);
    		get(v);
    		dp(v, u);
    	}
    	son[u][0] = x;
    	siz[u] = y;
    	update(u);
    }
    // Driver. For each test case: reset state, read the n-1 edges, run the
    // rooting and re-rooting passes, then print ans.
    int main() {
    	scanf("%d", &t);
    	while (t--) {
    		// Wipe the previous test's lists and heavy children.
    		// `n` still holds the previous size here.
    		for (int i = 1; i <= n; i++) {
    			g[i].clear();
    			son[i][0] = son[i][1] = 0;
    		}
    		ans = 0;
    		idx = 0;
    		scanf("%d", &n);
    		for (int e = 1; e < n; e++) {
    			int a, b;
    			scanf("%d %d", &a, &b);
    			g[a].push_back(b);
    			g[b].push_back(a);
    		}
    		dfs(1, 0);
    		dp(1, 0);
    		printf("%lld\n", ans);
    	}
    	return 0;
    }
    
    
    • 1