본문 바로가기
개발공부/Python알고리즘

[Python] 바이너리 인덱스 트리(BIT, 펜윅 트리) 구현

by 왜지? 2023. 6. 4.
반응형

바이너리 인덱스 트리( BIT, 펜윅트리 )

 

[Python] Segment Tree(세그먼트 트리) 설명 및 구현(백준 2042번)

이번 포스팅에서는 Segment Tree(세그먼트 트리)의 개념과 Python Code를 설명합니다. Code만 참고하실 분은 포스팅 가장 아랫부분으로 내려가시면 됩니다. Segment Tree(세그먼트 트리, 구간트리) 알고리즘

won-developer-log.tistory.com

앞서 소개한 세그먼트 트리( Segment Tree )와 비슷한 역할을 하는 Tree가 또 하나 있습니다. 바로 바이너리 인덱스 트리입니다. BIT, 펜윅 트리라고도 불리는 데요, BIT는 Segment Tree에 비해 작은 메모리를 사용합니다. 이를 위해 어떤 수 X를 이진수로 표기했을 때 마지막으로 나오는 1의 위치를 알아야 합니다. 

  • 1 = 1
  • 2 = 10₂
  • 3 = 11
  • 4 = 100₂
  • 5 = 101
  • 6 = 110₂
  • 7 = 111
  • 8 = 1000₂
  • 9 = 1001
  • 10 = 1010₂

위 수처럼 Bold 처리된 부분이 각 수들의 마지막으로 나오는 1의 위치입니다. 이 위치들을 나타내는 값을 list L에 저장하고 L[i] 라고 표현하면, L[1] = 1 , L[2] = 2, L[3] = 1, L[4] = 4, L[8] = 8 가 됩니다. list A에 어떤 값들이 주어졌다면, A[i] 값의 이진수 변환 후 마지막으로 나오는 1의 위치를 L[i]라고 정의한 것입니다. 

 

이제 A[1] ~ A[N]이 주어졌을 때, Tree[i]에는 A[i]부터 앞으로 L[i] 개의 합이 저장됩니다. 아래 그림을 보고 다시 설명해 보겠습니다. 우선 L[i]를 나타낸 표는 아래와 같습니다. L[1] = 1 이므로 A[1] 부터 1개를 나타냅니다. L[2] = 2 이므로 A[2] 부터 앞으로 2개인 A[1]+A[2] 값을 나타냅니다. L[4]는 4이므로 A[4]부터 앞으로 4개인 A[1] ~ A[4] 까지의 합을 나타냅니다.   

L[ i ]를 구하기 위해서는 비트의 연산을 알아야합니다. 

L[ i ] = i & -i 로 정의할 수 있는데 이유는 아래 연산과정을 통해 증명됩니다. 

      -num = ~num + 1
       num = 100110101110101100000000000
      ~num = 011001010001010011111111111
      -num = 011001010001010100000000000
num & -num = 000000000000000100000000000

-i 는 i의 이진법 표기에서 0과 1을 바꾼 후 1을 더해주면 되기 때문에 위 연산이 성립됩니다. 

 

이번에는 Tree[i]를 살펴보겠습니다. Tree[i]에는 A[i]에서 앞으로 L[i] 개의 합을 저장한다고 했습니다. 

A = [3, 2, 5, 7, 10, 3, 2, 7, 8, 2, 1, 9, 5, 10, 7, 4] 라고 정의한 경우 앞서 구한 L[i] 값으로 몇 칸 앞의 값까지 더하여 Tree[i]에 저장할지 결정합니다. 

 

Tree[1]의 경우, L[1] = 1 이므로 A[1] 값 하나인 3만 저장합니다. 

Tree[2]의 경우, L[2] = 2 이므로 A[2] 부터 2칸 앞인 A[1] 값까지 더하여 A[1]+A[2] = 5를 저장합니다. 

Tree[3]의 경우, L[3] = 1 이므로 A[3] 값 하나인  5만 저장합니다.

Tree[4]의 경우, L[4] =4 이므로 A[4] 부터 16칸 앞인 A[1] 값까지, A[1] + ... + A[16] = 17을 저장합니다. 

 

동일하게 Tree 값을 업데이트할 수 있습니다. 

 

Tree[13]의 경우, L[13] = 1 이므로 A[13] 값 하나인 5만 저장합니다.

Tree[14]의 경우, L[14] = 2 이므로 A[14] 부터 2칸 앞인 A[13] 값까지 더하여 A[13]+A[14] = 15를 저장합니다. 

Tree[15]의 경우, L[15] = 1 이므로 A[15] 값 하나인  7만 저장합니다.

Tree[16]의 경우, L[16] =16 이므로 A[16] 부터 4칸 앞인 A[1] 값까지, A[1] + A[2] + A[3] + A[4] = 85를 저장합니다. 

 

감이 오시나요? 

이제는 앞서 살펴본 문제처럼 구간의 합을 구하는 방법과 값이 업데이트 된 경우 구간에 반영하는 방법을 살펴보겠습니다. 

 

 

 

 

 

구간 합 구하기

위에서 구한 Tree를 가지고 구간합을 구해보도록 하겠습니다. 

예를 들어 A[1]부터 A[13]까지의 합을 구한다고 하겠습니다. 13을 이진수로 표현하면 1101₂ 입니다. 

앞서 말한 Tree의 성질을 이용하면 A[1] + ... + A[13] = Tree[1101₂] + Tree[1100₂] + Tree[1000₂] 으로 나타낼 수 있습니다. 

위 표에서 1101₂ → 1100₂ → 1000₂ 으로 마지막 1의 위치를 제거하면서 합을 더해주는 것 보이시나요? 이것을 코드로 나타내면 아래와 같습니다. 

# i번째 수까지의 누적 합을 계산하는 함수
def sum(i):
    result = 0
    while i > 0:
        result += tree[i]
        # 0이 아닌 마지막 비트만큼 빼가면서 이동
        i -= (i & -i)
    return result

 

모든 i에 대하여 A[i] 까지의 합을 구하는 과정을 그림으로 나타내면 아래와 같습니다. 

이때 구간 합 A[ start ] ~ A [ end ]는 A[1] +... + A[start] 에서 A[1] + ... + A[start -1] 을 뺀 값이므로 sum(end) - sum(start-1)로 구현할 수 있습니다. 

 

값 업데이트 하기

배열의 수를 변경하고자 한다면 해당 배열 값을 수정하고 Tree에서도 해당 배열이 포함된 값들을 모두 업데이트 해줘야하 합니다. 아래와 같이 업데이트하고자 하는 수의 인덱스 i를 이진수로 변경한 후 마지막 1의 값을 더하는 방식으로 구현할 수 있습니다. 

# i번째 수를 dif만큼 더하는 함수
def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

i += ( i & -i ) 부분을 보면, 이진수의 가장 마지막 1을 더해가며 Tree의 끝가지 진행해 나가는 것을 알 수 있습니다.  

아래 그림을 보면, i를 업데이트 했을 때 변경해야 하는 모든 Tree[i]가 그려져 있습니다. 

예를 들어 A[13] 값을 업데이트한다면, 우선 A[13]을 업데이트하고, A[13] = 1101₂의 마지막에 0001₂을 더한 A[14]=1110₂를 업데이트합니다. 그 후 A[14]=1110₂에 0010₂을 더한 A[16] = 10000₂을 업데이트하고 종료합니다. 

 

예제 ) 백준 2042번 문제 : 구간 합 구하기

- 앞서 살펴몬 Segment Tree 풀이보다 메모리를 적게 사용하여 풀 수 있습니다. 

- O(logN)의 시간 복잡도를 보장합니다. 

import sys
input = sys.stdin.readline

# 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
n, m, k = map(int, input().split())

arr = [0] * (n + 1)
tree = [0] * (n + 1)

# i번째 수까지의 누적 합을 계산하는 함수
def prefix_sum(i):
    result = 0
    while i > 0:
        result += tree[i]
        # 0이 아닌 마지막 비트만큼 빼가면서 이동
        i -= (i & -i)
    return result

# i번째 수를 dif만큼 더하는 함수
def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

# start부터 end까지의 구간 합을 계산하는 함수
def interval_sum(start, end):
    return prefix_sum(end) - prefix_sum(start - 1)

for i in range(1, n + 1):
    x = int(input())
    arr[i] = x
    update(i, x)

for i in range(m + k):
    a, b, c = map(int, input().split())
    # 업데이트(update) 연산인 경우
    if a == 1:
        update(b, c - arr[b]) # 바뀐 크기(dif)만큼 적용
        arr[b] = c
    # 구간 합(interval sum) 연산인 경우
    else:
        print(interval_sum(b, c))

 

이상으로 Binary Index Tree( BIT, 펜윅트리 )에 대한 포스팅을 마칩니다. 

반응형

댓글