Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 3192 Accepted Submission(s): 371
Problem Description: You are given a rooted tree of N nodes, labeled from 1 to N. To the i-th node a non-negative value a_i is assigned. An ordered pair of nodes (u, v) is said to be weak if
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) a_u × a_v ≤ k.
Can you find the number of weak pairs in the tree?
Input: There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, N and k, respectively.
The second line contains N space-separated integers, denoting a_1 to a_N.
Each of the subsequent N − 1 lines contains two space-separated integers u and v defining an edge connecting nodes u and v, where node u is the parent of node v.
Constraints:
1 ≤ N ≤ 10^5
0 ≤ a_i ≤ 10^9
0 ≤ k ≤ 10^18
Output: For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input:
1
2 3
1 2
1 2
Sample Output:
1
题意：给你一棵有根树，求有多少点对 (u, v)（u ≠ v）满足 u 是 v 的祖先且点权 a_u × a_v ≤ k。
思路：问题转化一下，就是对于每一个点 u，求以该点为根的子树中有多少个点 v 的权值小于等于 ⌊k / a_u⌋（即严格小于 ⌊k / a_u⌋ + 1；a_u = 0 时子树内所有点都满足）。由于是子树的问题，可以先求出 dfs 序，将问题转化为区间查询：对于每个点 u，在区间 [st[u]+1, ed[u]] 中有多少个值严格小于 ⌊k / a_u⌋ + 1。求一个区间内有多少个数小于某个数 G 可以用分块实现：完整的块二分，两边的块暴力，复杂度约 O(n√n log n)。
坑点：一直以为 1 就是 root……
#include <cstdio>
#include <cstring>
#include <iostream>
#include <cmath>
#include <queue>
#include <algorithm>
#include <stack>
#include <queue>
#include <map>
#include <set>
#include <vector>
#include <cstdlib>
using namespace std;
typedef long long ll;

// Capacity of every per-node / per-edge array (N <= 1e5 per the statement,
// so this leaves generous headroom).
const int maxn = 300000 + 10;

// Adjacency-list edge: `to` is the child node, `nex` is the index of the next
// edge leaving the same parent (-1 terminates the list, see init()).
struct Edge {
    int to, nex;
} e[maxn];

int n;                // number of nodes in the current test case
ll k;                 // weak-pair threshold: a_u * a_v <= k
ll a[maxn];           // node values, 0-indexed
int root;             // the node that has no parent (not necessarily node 0)
int head[maxn], tot;  // head[u]: first edge index out of u; tot: edges stored so far
// Reset the adjacency list before reading a new test case.
void init() {
    // Every byte 0xFF makes every int -1, i.e. every node starts with an empty list.
    memset(head, -1, sizeof head);
    tot = 0;
}
void add(int u, int v) {
e[tot].to = v;
e[tot].nex = head[u];
head[u] = tot++;
}
// flag[v] becomes non-zero once v is seen as a child; the root is the node left at 0.
int flag[maxn];
// id[t]: value of the node whose DFS entry time is t (node values laid out in DFS order).
ll id[maxn];
// Read one test case: n and k, the n node values, then the n - 1 parent->child
// edges (1-based labels in the input, stored 0-based). Also determines the root,
// which is the only node that never appears as a child.
void input() {
    // %lld is the ISO C/C++ specifier for long long (%I64d is MSVC-only).
    scanf("%d%lld", &n, &k);
    memset(flag, 0, sizeof flag);
    for (int i = 0; i < n; ++i) scanf("%lld", &a[i]);

    // A tree on n nodes has exactly n - 1 edges.
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        --u;
        --v;
        flag[v] = 1;  // v has a parent, so it cannot be the root
        add(u, v);
    }

    // Node 1 is not guaranteed to be the root; pick the node without a parent.
    root = 0;
    for (int i = 0; i < n; ++i) {
        if (flag[i] == 0) {
            root = i;
            break;
        }
    }
}
int st[maxn], ed[maxn], tim;
void dfs(int u) {
st[u] = ++tim;
id[tim] = a[u];
for(int i = head[u]; ~i; i = e[i].nex) {
dfs(e[i].to);
}
ed[u] = tim;
}const int SIZE = ;
ll block[maxn / SIZE + ][SIZE + ];
void init2() {
int b = , j = ;
for(int i = ; i < n; ++i) {
block[b][j] = id[i];
if(++j == SIZE) { b++; j = ; }
}
for(int i = ; i < b; ++i) sort(block[i], block[i] + SIZE);
if(j) sort(block[b], block[b] + j);
}int query(int L, int R, ll v) {
int lb = L / SIZE, rb = R / SIZE;
int k = ;
if(lb == rb) {
for(int i = L; i <= R; ++i) if(id[i] < v) k++;
} else {
for(int i = L; i < (lb + ) * SIZE; ++i) if(id[i] < v) k++;
for(int i = rb * SIZE; i <= R; ++i) if(id[i] < v) k++;
for(int b = lb + ; b < rb; ++b) {
k += lower_bound(block[b], block[b] + SIZE, v) - block[b];
}
}
return k;
}
// Solve one test case: lay the tree out in DFS order, build the sorted chunks,
// then for every internal node u count descendants v with a[u] * a[v] <= k and
// print the total.
void solve() {
    tim = -1;  // the root gets time 0, so DFS times line up with id[0 .. n-1]
    dfs(root);
    init2();

    ll ans = 0;
    for (int i = 0; i < n; ++i) {
        if (st[i] == ed[i]) continue;  // leaf: no descendants
        if (a[i] == 0) {
            // 0 * a_v = 0 <= k (k >= 0) for every descendant.
            ans += ed[i] - st[i];
            continue;
        }
        // a[i] * x <= k  <=>  x <= floor(k / a[i])  <=>  x < floor(k / a[i]) + 1.
        // Dividing avoids the overflow that a[i] * x could cause.
        ll v = k / a[i] + 1;
        ans += query(st[i] + 1, ed[i], v);  // descendants only: skip u itself
    }
    printf("%lld\n", ans);
}
// Entry point: repeatedly reads a test-case count T until EOF (or malformed input)
// and solves each of the T cases.
int main() {
#ifdef LOCAL
    freopen("in", "r", stdin);
#endif
    int cas;
    // == 1 (not ~scanf) also stops on a matching failure instead of looping forever.
    while (scanf("%d", &cas) == 1) {
        while (cas--) {
            init();
            input();
            solve();
        }
    }
    return 0;
}