CodeForces 1128C - Primes and Multiplication 题解

long long 爆的好啊!!

题目链接

我们把要求的式子展开

f(x,1)f(x,2)f(x,n) =g(1,p1)g(1,p2)g(1,pn) g(2,p1)g(2,p2)g(2,pn) g(3,p1)g(3,p2)g(3,pn)  g(n,p1)g(n,p2)g(n,pn)\begin{align*}&f(x, 1) \cdot f(x, 2) \cdot \ldots \cdot f(x, n) \\\ =&g(1,p_1)\cdot g(1,p_2)\cdot \ldots \cdot g(1,p_n) \\\ &g(2,p_1)\cdot g(2,p_2)\cdot \ldots \cdot g(2,p_n) \\\ &g(3,p_1)\cdot g(3,p_2)\cdot \ldots \cdot g(3,p_n) \\\ &\vdots \\\ &g(n,p_1)\cdot g(n,p_2)\cdot \ldots \cdot g(n,p_n)\end{align*}

然后每次计算一列,由于pp是质数,当且仅当n=kpjn=k\cdot p^jg(n,p)=jg(n,p)=j,否则g(n,p)=1g(n,p)=1。由于同一列中pp都是相同的,所以只要计算指数之和就行了。直接分析代码:

ll tmp = it;
ll cnt = 0;
while (tmp <= n) {
    cnt += (n / tmp);
    if (n / tmp < it)
    break;
    tmp *= it;
}
if (cnt == 0)
    continue;
ans = ans * binpow(it, cnt) % mod;

n / tmp的结果就是对于当前的 tmp,1,2,3,,n1,2,3,\ldots,n中有几个可以整除 tmp。 对于1,2,,n1,2,\ldots,n每个数字都被筛过g(n,p)g(n,p)次,所以累加每一次的n / tmp就是指数之和了。注意tmp *= it可能会爆 long long 所以乘之前要先检查一下(做的时候被卡了,直接自闭)。

完整代码:

#include<iostream>
#include<vector>
 
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define fore(i, l, r) for (int i = (int)(l); i <= (int)(r); ++i)
#define ford(i, n) for (int i = (int)(n)-1; i >= 0; --i)
#define pb push_back
#define ms(a, x) memset(a, x, sizeof(a))
#define endl '\n'
using namespace std;
 
typedef long long ll;
const int INF = 0x3f3f3f3f;
typedef pair<int, int> pii;
 
const int mod = 1e9 + 7;
 
long long binpow(long long a, long long b) {
  long long res = 1;
  while (b > 0) {
    if (b & 1)
      res = (res * a) % mod;
    a = (a * a) % mod;
    b >>= 1;
  }
  return res;
}
int main() {
  ios::sync_with_stdio(false);
  cin.tie(0);
  vector<int> pr;
  ll x, n;
  cin >> x >> n;
  if (x % 2 == 0) {
    while (x % 2 == 0)
      x /= 2;
    pr.pb(2);
  }
  for (int i = 3; i * i <= x; i += 2) {
    if (x % i == 0) {
      pr.pb(i);
      while (x % i == 0)
        x /= i;
    }
  }
  if (x > 1)
    pr.pb(x);
  ll ans = 1;
  for (auto it : pr) {
    ll tmp = it;
    ll cnt = 0;
    while (tmp <= n) {
      cnt += (n / tmp);
      if (n / tmp < it)
        break;
      tmp *= it;
    }
    if (cnt == 0) continue;
    ans = ans * binpow(it, cnt) % mod;
  }
  cout << ans;
  return 0;
}