week3 牛客6 C
推广到n项,((sum(ai[i,1,n]))^3=sum(a(i)^3)+6(1≤i<j≤n∑(I[i]2I[j]+I[i]I[j]2)+6*(1<=i<j<k<=n)(sum[i][j][k])合并同类,因为 所有[i]^2,[i]^3和[i]的值相同,所以最后f(p)^3= i=1∑nI[i]+61≤i<j≤n∑I[i]I[j]+61≤i<j<k≤n∑I[i]I[j]I[k]对于su
对于一个排列 P,定义 f(P) 如下:(P为单调增栈) 给定一个整数 n,求所有长度为 n 的排列 P 0 的 (f(P 0 ))3次方 的和,对 998244353 取模。 Input 本题有多组输入数据。 第一行输入一个正整数 T(1 ≤ T ≤ 105 ),表示输入数据组数。 接下来的每组输入数据,输入一个正整数 n(1 ≤ n ≤ 5 · 105 )。 Output 对于每组输入数据,输出一行一个整数,表示答案。 Example standard input standard output 2 3 3741 53 805156151
题意 给定 T 次询问,每次给定 n,求所有 n! 种排列执行递增单调栈后,单 调栈的大小的三次方的和。 1 ≤ T ≤ 105,1 ≤ n ≤ 5 · 105。
现在对于每一种排列P,f(p)的值体现为从左到右数,当前数为最小值的个数,形式化的,f(p)=(sumi~n)[i是从左往右数的最小值?1:0],那么我们要求的是(sum(p的所有排列))(f(p))^3;
我们可以看出,有这个三次方在会让问题变得很复杂,我们需要降次;
因为[i]的值只能是0/1,所有[i]^2,[i]^3和[i]的值相同,
对于(a+b+c)^3=a3+b3+c3+3a2b+3a2c+3b2a+3b2c+3c2a+3c2b+6abc
推广到n项,((sum(ai[i,1,n]))^3=sum(a(i)^3)+6(1≤i<j≤n∑(I[i]2I[j]+I[i]I[j]2)+6*(1<=i<j<k<=n)(sum[i][j][k])
//所以为什么会有三个项,因为三次方会让至多三个项同时出现
合并同类,因为 所有[i]^2,[i]^3和[i]的值相同,所以最后f(p)^3= i=1∑nI[i]+61≤i<j≤n∑I[i]I[j]+61≤i<j<k≤n∑I[i]I[j]I[k]
那么答案就是所有可能排列的这个式子的累加
我们交换求和顺序
∑f(P)3=i∑P∑IP[i]+6i<j∑P∑IP[i]IP[j]+6i<j<k∑P∑IP[i]IP[j]IP[k]
现在我们来思考一个问题,在P的所有排列中i是从左到右的最小值的情况有多少种?
我们可以从概率的角度思考这个问题,首先总共有n!种排列,在前1~i个数中i排在最左边的概率是1/i,所有答案就是n!/i
那么加入j呢?
j在前j个数中j最小&&i在前i个数中最小------n!/(i*j)
再加入k,同理-------n!/(i*j*k)
于是这道题就变成了
P∑f(P)3=n!i∑i1+6i<j∑ij1+6i<j<k∑ijk1
根据题目的数据范围,很显然我们需要一个O(n)的预处理和一个O(1)的查询,
对于sum(1/i),很显然h1(n)=h1(n-1)+1/n;
对于sum(1/(i*j)),考虑j=n,h2(n)=h2(n-1)+h1(n-1)/n
sum(1/(i*j*k))同理,h3(n)=h3(n-1)+h2(n-1)/n;
初始值赋全0;
于是本题的整体思路就形成了,本题再nk6作为签到题还是有点复杂了hhh
以下是代码:#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int M=500000;
typedef long long ll;
vector<ll>inv(M+1,0),h1(M+1,0),h2(M+1,0),h3(M+1,0),f(M+1,1),ans(M+1,0);
void get(){
inv[1]=1;
for(int i=2;i<=M;i++){
inv[i]=(mod-(mod/i)*inv[mod%i]%mod)%mod;
}
for(int i=1;i<=M;i++){
h1[i]=(h1[i-1]+inv[i])%mod;
h2[i]=(h2[i-1]+inv[i]*inv[i]%mod)%mod;
h3[i]=(h3[i-1]+inv[i]*inv[i]%mod*inv[i]%mod)%mod;
}
for(int i=1;i<=M;i++){
f[i]=f[i-1]*i%mod;
}
for(int n=1;n<=M;n++){
ll t=(h1[n]*h1[n]%mod*h1[n]%mod);
t=(t+3*h1[n]%mod*h1[n]%mod)%mod;
t=(t+h1[n])%mod;
t=(t-3*h2[n])%mod;
t=(t-3*h1[n]%mod*h2[n]%mod)%mod;
t=(t+2*h3[n])%mod;
if(t<0)t+=mod;
ans[n]=f[n]*t%mod;
}
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
get();
int t;
cin>>t;
while(t--){
int num;
cin>>num;
cout<<ans[num]<<"\n";
}
}
官方题解:
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
const int N = 5e5 + 1;
vector<mint> inv(N), fac(N);
vector<mint> sum1(N), sum2(N), sum3(N);
fac[0] = inv[0] = 1;
for (int i = 1; i < N; i++) {
inv[i] = mint(1) / i;
fac[i] = fac[i - 1] * i;
sum1[i] = sum1[i - 1] + inv[i];
sum2[i] = sum2[i - 1] + sum1[i - 1] * inv[i];
sum3[i] = sum3[i - 1] + sum2[i - 1] * inv[i];
}
int t;
cin >> t;
while (t--) {
int n;
cin >> n;
auto ans = sum1[n] + sum2[n] * 6 + sum3[n] * 6;
ans *= fac[n];
cout << ans << '\n';
}
}
更多推荐


所有评论(0)