본문 바로가기
개발공부/Python알고리즘

[Python] SQRT Decomposition(제곱근 분할법) 설명 및 구현 (백준14438, 2042)

by 왜지? 2023. 9. 26.
반응형

이번 포스팅에서는 SQRT Decomposition(제곱근 분할법)의 개념과 Python Code를 설명합니다. Code만 참고하실 분은 포스팅 가장 아랫부분으로 내려가시면 됩니다. 

SQRT Decomposition(제곱근 분할법, 평방분할법) 알고리즘

SQRT Decomposition은 구간에서 최소값, 최대값, 구간합 등을 구할 때 가장 기본적으로 사용되는 알고리즘입니다. 전체 크기가 N인 구간에서 N을 전부 다 탐색하는 것이 아니라,  ROOT(N) 번만 탐색하여 시간을 줄일 수 있습니다.  

 

크기가 N인 배열이 주어졌을때, 아래와 같은 연산을 수행하는 문제를 생각해 보겠습니다. 

    - 1) 구간 l,r ( l ≤ r )을 주고 l ~ r 사이 배열의 최소값(min( A[l], A[l+1],... , A[r-1],A[r])) 값을 구하라 

    - 2) 배열 A에서 i번째 값을 v로 바꾸기. ( A[i] = v )

위 연산을 for문 만을 이용해 푼다면 최소값을 구하는데 O(N)이, M회 연산을 수행하면 O(NM) 시간복잡도를 가집니다. 배열 A값을 바꾸는 2번 연산을 수행하려면 모든 S를 변경해야 하므로 O(N)의 시간복잡도를 가집니다. 이런 상황에서 SQRT Decomposition은 O(lgN) 시간복잡도로 위 문제를 해결할 수 있습니다.  

 

 

SQRT Decomposition(제곱근 분할법, 평방분할법) 구현 및 예시

SQRT Decomposition을 사용하기 위해서는 크기 N인 리스트를 ROOT(N) 크기의 작은 그룹으로 나누고, 각 그룹별로 접근해야 합니다. 아래 N=11일 때의 예시를 보겠습니다. 

 

■ 그룹 세팅하기 

## sqrt(n) 만큼의 그룹 만들기
r = int(n**0.5) ## 루트값
g = n//r ## 그룹개수
if n % r != 0 :
    g += 1
d = [0]*g
for i in range(n):
    if i % r == 0:
        d[i//r] = a[i]
    else:
        if d[i//r] > a[i]:
            d[i//r] = a[i]

위 예시에서 N=11이고, ROOT(N) = 3으로, 3이 그룹의 크기(r)이됩니다. 그룹의 총 개수(g)는 리스트 전체 개수 11을 그룹의 크기 3으로 나눈 수가 3.* 이므로 4개가 됩니다.  또 각 그룹의 최소값을 d[i] 배열에 따로 저장해 둡니다. 

 

■ 수 업데이트하기 

i번째 수를 v로 바꾸는 2번 연산을 위해서는 i번째 수가 포함된 그룹의 최소값만 바꿔주면 됩니다. 즉, i가 포함된 d의 요소만 비교하면 되기 때문에 그룹의 크기(r)에 해당하는 ROOT(N) 개의 값만 탐색하면 됩니다. 

## 수 업데이트 ( 최소값 갱신 )
def update(a,d,r,idx, val):
    a[idx] = val
    group = idx // r
    start = group * r
    end = start + r
    if end > len(a):
        end = len(a)
    d[group] = a[start]
    for i in range(start,end):
        if d[group] > a[i] :
            d[group] = a[i]

우선 바꾸고자 하는 값을 배열 a에서 바꿔주고, 해당 index의 그룹 인덱스(group)를 찾습니다. 그룹 인덱스(group)에 그룹의 크기(r)을 곱해주면 선택된 그룹의 배열 내 인덱스(start)를 알 수 있고, start에 그룹의 크기(r)을 더해주면 그룹의 끝 인덱스(end)를 알 수 있습니다. 

 

가장 마지막 그룹일 경우 그룹의 끝 인덱스(end)가 배열의 길이(N)을 넘을 수 있으니 그럴 경우는 배열의 길이(N)을 end로 설정합니다. 이후 d[group] 에 a[start]부터 a[end] 까지의 값을 탐색하여 최소값을 업데이트 해줍니다.  

 

 

■ 최소값 구하기

## 구간의 최소값 구하기
def query(a,d,r,start,end):
    result = a[start]
    
    ## start, end가 같은 그룹일때 → 그룹 내에서 전체 탐색
    if start//r == end//r :
        for i in range(start,end+1):
            if result > a[i] :
                result = a[i]
        return result

    ## start 위치를 그룹의 첫번째일떄까지 반복
    while True :
        if result > a[start]:
            result = a[start]
        start += 1
        if start % r == 0 :
            break

    ## end 위치를 그룹의 마지막일떄까지 반복
    while True:
        if result > a[end]:
            result = a[end]
        end -= 1
        if end % r == r-1:
            break

    ## 그룹의 대표 최소값을 가져와서 비교
    startG = start // r
    endG = end // r
    for i in range(startG,endG+1):
        if result > d[i]:
            result = d[i]
    return result

최소값을 구할 때는 구하고자 하는 구간의 시작, 끝 인덱스에 따라 방법이 달라집니다. start, end가 같은 그룹에 속할 때는 d배열을 사용하지 않고 그룹 내에서 start~end 사이를 전체 탐색 해야합니다. 구간의 최소값이 start ~ end 사이에 없을 수도 있기 때문입니다. 

 

start, end가 다른 그룹에 속할때는 start의 위치가 각 그룹의 첫 번째 index에 오고 end의 index가 그룹의 마지막에 올 때까지 그룹 내 탐색을 이어갑니다. 이후 start와 end가 각 그룹의 첫 번째와 마지막에 왔다면 탐색을 멈추고 그룹 최소값 배열인 d를 이용하여 비교합니다. 

 

start의 그룹인 startG와 end의 그룹인 endG 사이의 모든 그룹의 최소값을 비교하면 됩니다. 만약 startG가 endG보다 크면 for문은 돌지 않습니다. 이렇게 완성한 코드는 아래와 같습니다.

 

■ 백준 14438 수열과 쿼리 17 파이썬 코드 

>> 백준 14438

#### 14438 수열과 쿼리 17 : 최소값 찾기

## 수 업데이트 ( 최소값 갱신 )
def update(a,d,r,idx, val):
    a[idx] = val
    group = idx // r
    start = group * r
    end = start + r
    if end > len(a):
        end = len(a)
    d[group] = a[start]
    for i in range(start,end):
        if d[group] > a[i] :
            d[group] = a[i]

## 구간의 최소값 구하기
def query(a,d,r,start,end):
    result = a[start]
    
    ## start, end가 같은 그룹일때 → 그룹 내에서 전체 탐색
    if start//r == end//r :
        for i in range(start,end+1):
            if result > a[i] :
                result = a[i]
        return result

    ## start 위치를 그룹의 첫번째일떄까지 반복
    while True :
        if result > a[start]:
            result = a[start]
        start += 1
        if start % r == 0 :
            break

    ## end 위치를 그룹의 마지막일떄까지 반복
    while True:
        if result > a[end]:
            result = a[end]
        end -= 1
        if end % r == r-1:
            break

    ## 그룹의 대표 최소값을 가져와서 비교
    startG = start // r
    endG = end // r
    for i in range(startG,endG+1):
        if result > d[i]:
            result = d[i]
    return result

n = int(input())
a = list(map(int, input().split()))

## sqrt(n) 만큼의 그룹 만들기
r = int(n**0.5) ## 루트값
g = n//r ## 그룹개수
if n % r != 0 :
    g += 1
d = [0]*g
for i in range(n):
    if i % r == 0:
        d[i//r] = a[i]
    else:
        if d[i//r] > a[i]:
            d[i//r] = a[i]

q = int(input())
for _ in range(q):
    t, t1, t2 = map(int, input().split())
    if t == 1 :
        i, v = t1, t2
        update(a,d,r,i-1, v)
    else :
        i,j = t1,t2
        print(query(a,d,r,i-1,j-1))

 

 

■ 백준 2042 구간 합 구하 파이썬 코드 

위와 마찬가지 방법으로 구간합을 구할 수도 있습니다. min을 구하는 부분을 전부 + 연산으로 바꾸면 됩니다. 업데이트 시에는 start, end를 신경 쓰지 않고 값만 수정해 주면 됩니다.

>> 백준 2042 

#### 14438 수열과 쿼리 17 : 구간 합 구하기
## 수 업데이트 ( 최소값 갱신 )
def update(a,d,r,idx, val):
    group = idx//r
    d[group] = d[group] - a[idx] + val
    a[idx] = val

## 구간의 합 구하기
def query(a,d,r,start,end):
    result = 0

    ## start, end가 같은 그룹일때 → 그룹 내에서 전체 탐색
    if start//r == end//r :
        for i in range(start,end+1):            
            result += a[i]
        return result

    ## start 위치를 그룹의 첫번째일떄까지 반복
    while True :
        result += a[start]
        start += 1
        if start % r == 0 :
            break

    ## end 위치를 그룹의 마지막일떄까지 반복
    while True:
        result += a[end]
        end -= 1
        if end % r == r-1:
            break

    ## 그룹의 대표 최소값을 가져와서 비교
    startG = start // r
    endG = end // r
    for i in range(startG,endG+1):
        result += d[i]
    return result

n, m, k = map(int, input().split())
a = [int(input()) for _ in range(n)]

## sqrt(n) 만큼의 그룹 만들기
r = int(n**0.5) ## 루트값
g = n//r ## 그룹개수
if n % r != 0 :
    g += 1
d = [0]*g
for i in range(n):
    d[i//r] += a[i]

q = m+k
for _ in range(q):
    t, t1, t2 = map(int, input().split())
    if t == 1 :
        i, v = t1, t2
        update(a,d,r,i-1, v)
    else :
        i,j = t1,t2
        print(query(a,d,r,i-1,j-1))

 

지금까지 SQRT Decomposition 설명 및 파이썬 예시 구현이었습니다. SQRT Decomposition과 개념이 비슷하지만 속도가 더 빠른 SegmentTree 알고리즘이 궁금하신 분은 아래 글을 참고 부탁드립니다. 

 

>> [파이썬] SegmentTree 설명 및 구현 

 

반응형

댓글