자료구조 - BST 학습

이진 탐색 트리(Binary-Search-Tree)

  • 모든 원소는 상이한 키를 갖는다.
  • 왼쪽 서브트리에 있는 원소의 키들은 그 루트의 키보다 작다.
  • 오른쪽 서브트리에 있는 원소의 키들은 그 루트의 키보다 크다.
  • 왼쪽 서브트리와 오른쪽 서브트리도 모두 이진 탐색 트리이다.

BST 코드 구현 및 해설

  • 각 노드는 아래와 같이 left, right, key와 같이 3개의 필드를 가진다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    # 노드 생성
    class TreeNode:
    def __init__(self, key):
    self.__key = key
    self.__left = None
    self.__right = None

    def __del__(self):
    print(f'key {self.__key} is deleted')
    # get 함수
    @property
    def key(self):
    return self.__key
    # set 함수
    @key.setter
    def key(self, key):
    self.__key = key

    @property
    def left(self):
    return self.__left

    @left.setter
    def left(self, left):
    self.__left = left

    @property
    def right(self):
    return self.__right

    @right.setter
    def right(self, right):
    self.__right = right

    class BST:
    def __init__(self):
    self.root = None

    # 루트 조회 함수
    def get_root(self):
    return self.root

    # 전위순회
    def preorder_traverse(self, cur, func):
    if not cur:
    return

    func(cur)
    self.preorder_traverse(cur.left, func)
    self.preorder_traverse(cur.right, func)

    # 특정 key 삽입 함수
    def insert(self, key):
    new_node = TreeNode(key)

    cur = self.root
    # root 가 존재하지 않는 경우
    if not cur:
    # 새로운 노드를 루트로 설정
    self.root = new_node
    return

    while True:
    parent = cur
    # insert 하려는 key값이 현재 cur가 참조하고 있는 key값보다 작은 경우
    if key < cur.key:
    # cur의 왼쪽 자식 노드를 다시 cur가 참조한다.
    cur = cur.left
    # cur가 참조하는 노드가 없는 경우
    if not cur:
    # 새로운 노드를 parent가 참조하는 노드의 왼쪽 자식으로 설정한다.
    parent.left = new_node
    return
    # insert하려는 key값이 현재 cur가 참조하고 있는 key 값보다 큰 경우
    else:
    # cur가 참조하고 있는 노드의 오른쪽 자식을 다시 cur가 참조한다.
    cur = cur.right
    if not cur:
    # 새로운 노드를 parent가 참조하는 노드의 오른쪽 자식으로 설정한다.
    parent.right = new_node
    return

    # 특정 target 검색하여 해당 key 값 반환
    def search(self, target):
    cur = self.root
    while cur:
    if cur.key == target:
    return cur.key
    elif cur.key > target:
    cur = cur.left
    elif cur.key < target:
    cur = cur.right
    # target에 해당하는 key 값 없는 경우 None 반환
    return None


    def __remove_recursion(self, cur, target):
    if not cur:
    return None, None
    elif target < cur.key:
    cur.left, rem_node = self.__remove_recursion(cur.left, target)
    elif target > cur.key:
    cur.right, rem_node = self.__remove_recursion(cur.right, target)
    # target == cur.key 인 경우,
    else:
    # 제거할 target node에 자식 노드가 없는 경우
    if not cur.left and not cur.right:
    rem_node = cur
    cur = None
    # 제거할 target node가 왼쪽 자식만 갖고 있는 경우
    elif not cur.right:
    rem_node = cur
    cur = cur.left
    # 제거할 target node가 오른쪽 자식만 갖고 있는 경우
    elif not cur.left:
    rem_node = cur
    cur = cur.right
    # 제거할 target node가 양쪽 자식 모두 갖고 있는 경우
    else:
    # 제거할 target node의 왼쪽 서브트리중에서 제일 큰 key 값을 갖는 노드를 replace가 참조하게 한다.
    replace = cur.left
    while replace.right:
    replace = replace.right
    # replace가 참조하는 key 값과 cur 가 참조하고 있는 key 값을 서로 교체한다.
    cur.key, replace.key = replace.key, cur.key
    cur.left, rem_node = self.__remove_recursion(cur.left, replace.key)
    return cur, rem_node

    # 특정 target 제거 함수
    def remove(self, target):
    self.root, removed_node = self.__remove_recursion(self.root, target)
    if removed_node:
    removed_node.left = removed_node.right = None
    return removed_node.key

    if __name__=="__main__":
    print('*'*100)
    bst = BST()

    bst.insert(6)
    bst.insert(3)
    bst.insert(2)
    bst.insert(4)
    bst.insert(5)
    bst.insert(8)
    bst.insert(10)
    bst.insert(9)
    bst.insert(11)

    f = lambda x: print(x.key, end=' ')

    bst.preorder_traverse(bst.get_root(), f)
    print()

    print(f'searched key : {bst.search(8)}')

    print(f'removed key : {bst.remove(6)}')

    bst.preorder_traverse(bst.get_root(), f)
    print()
    print('*'*100)