알고리즘/백준(BOJ)

[백준/C++] 10830번 행렬 제곱

beomseok99 2022. 9. 29. 14:16
728x90

https://www.acmicpc.net/problem/10830

 

10830번: 행렬 제곱

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

www.acmicpc.net

얼핏보면 그냥 행렬 곱을 반복해주면 되는거 아닌가? 싶을 수도 있다.

하지만! 최대 천억제곱까지 주어지므로 O(N)에는 해결할 수 없다!

때문에 빠른 거듭제곱 알고리즘을 사용해야한다.

이 알고리즘은 거듭제곱을 O(logN)에 해결할 수 있도록 해준다!

즉, N이 홀수일 땐 밑을 하나 꺼내서 N을 짝수로 만든다.

N이 짝수일 땐 밑을 제곱하고 N을 2로 나눈다.

그럼 어떻게 이 문제를 풀어야할까?

우선 행렬곱을 해주는 함수를 하나 만든다. 배열을 파라미터로 넘길 땐 주소가 넘어가므로, 함수의 타입을 정해서 return해주지 않아도 수정된 값을 반영할 수 있다는 점을 유의!

문제 풀이 방법

N=7인 경우를 예시로 들어보자.

  1. 우선 지수가 7(홀수)이니까, 단위행렬에 자기 자신을 곱해준다. (단위행렬에 어떤 한 행렬을 곱하면 자기 자신이 나온다!!) -> 현재 ans는 matrix의 1승
  2. 그리고 matrix * matrix를 해서 matrix^2을 만들어준다. -> 현재 matrix는 matrix^2
  3. 제곱을 하나 만들었으니까, 지수를 2로 나눠준다. (bottom-up 풀이방식!)
  4. 나눠진 지수가 3(홀수)이니까 아까 만들어둔 matrix^1에 matrix^2를 곱해준다 -> 현재 ans는 matrix의 3승
  5. 그리고 다시 matrix^2 * matrix^2를 해서 matrix^4를 만들어준다. -> 현재 matrix는 matrix^4
  6. 지수를 다시 2로 나누는데, 3/2 = 1 이므로 위 과정을 다시 반복한다.
  7. ans = ans(=matrix의 3승) * matrix(=matrix의 4승) => matrix^7, 즉 우리가 찾고자 했던 값이고 지수는 나누기 2할 경우 0이 되므로 while문은 종료하게 된다. 그리고 ans를 출력해주면 끝!

※ ans = matrix^3 = matrix * matrix^2이고, matrix = matrix^4 = matrix^2 * matrix^2

빠른 거듭제곱 알고리즘이 적용된 것을 확인할 수 있다.

#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;

long long N, B;
long long matrix[5][5];
long long tmp[5][5];
long long ans[5][5];

void multi(long long a[5][5], long long b[5][5]){
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            tmp[i][j] = 0; // 행렬 초기화
            for (int k = 0; k < N; k++) {
                tmp[i][j] += (a[i][k] * b[k][j]); // 행렬곱
            }
            tmp[i][j] %= 1000;
        }
    }

    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            a[i][j] = tmp[i][j];
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> N >> B;

    for (int i = 0; i < N; i++){
        for (int j = 0; j < N; j++) {
            cin >> matrix[i][j];
        }
        ans[i][i] = 1; // 정답행렬은 단위행렬로
    }

    while (B > 0){
        if (B % 2 == 1){
            multi(ans, matrix); // 단위행렬에 A행렬 곱하기
        }
        multi(matrix, matrix);
        B /= 2;
    }

    for (int i = 0; i < N; i++){
        for (int j = 0; j < N; j++) {
            cout << ans[i][j] << " ";
        }
        cout << '\n';
    }
    return 0;
}
728x90