분할 정복을 이용한 거듭제곱
시작하기 전에 분할 정복이 무엇일까요? 분할 정복은 문제를 작은 서브 문제로 나누어 푸는 방법입니다. 분할 정복 자체에 대해서는 다음 챕터에서 자세히 다루도록 하겠습니다.
어떤 정수 \( a \)를 \( n \)번 곱하는 문제를 생각해봅시다. 즉, \( a^n \)을 구하는 문제입니다. 정답이 너무 커질 수 있으니 \( 10^9 + 7 \)로 나눈 나머지를 구해야 합니다. 단순히 1에 \( a \)를 \( n \)번 곱해 \( O(n) \)에 처리할 수 있습니다.
int a, n;
cin >> a >> n;
const int MOD = 1e9 + 7;
int ans = 1;
for (int i = 0; i < n; i++)
ans = (ans * a) % MOD;
cout << ans << '\n';
하지만, 조금 더 빠른 방법이 있습니다. \( a^n \)을 \( a^{n - 1} \cdot a \)으로 생각할 수도 있지만, \( a^n \)을 \( a^{n / 2} \cdot a^{n / 2} \)로 생각할 수도 있습니다. 즉, 아래와 같은 점화식을 만들 수 있습니다.
\[ a^n =
\begin{cases}
\left(a^{n/2}\right)^2 & \text{if } n \text{ is even} \\
\left(a^{(n-1)/2}\right)^2 \cdot a & \text{if } n \text{ is odd}
\end{cases} \]
이렇게 지수를 절반씩 줄여나가는 방식을 분할 정복을 이용한 거듭제곱이라고 합니다. \( n \)제곱을 구하는 문제를 \( \frac{n}{2} \) 제곱을 구하는 문제로 나누어 푸는 것입니다. 이렇게 하면 시간복잡도가 \( O(\log n) \)으로 줄어듭니다. 코드로 구현하면
constexpr int MOD = 1e9 + 7;
int pow(int a, int n) {
if (n == 0) return 1;
if (n % 2 == 0) {
int half = pow(a, n / 2);
return (half * half) % MOD;
} else {
return (pow(a, n - 1) * a) % MOD;
}
}
이제 어떤 수를 \( n \) 제곱 한 값을 \( O(\log n) \)에 구할 수 있습니다. 위 코드에서는 재귀함수를 통해서 분할 정복을 구현했지만, 재귀를 사용하지 않고 반복문을 통해서도 구현할 수 있습니다. 반복문을 사용하는 것이 미묘하지만 조금 더 빠르기 때문에 자주 사용됩니다.
constexpr int MOD = 1e9 + 7;
int pow(int a, int n) {
int ans = 1;
while (n > 0) {
if (n % 2 == 1) {
ans = (ans * a) % MOD;
}
a = (a * a) % MOD;
n /= 2;
}
return ans;
}