Skip to content

implementation

文件信息

  • 📄 原文件:implementation.py
  • 🔤 语言:python

哈希表实现:包含链地址法与开放地址法(线性探测)两种哈希表、LRU 缓存,以及若干常见的哈希表算法题。

完整代码

python
from typing import TypeVar, Generic, Optional, List, Iterator, Tuple
from collections import OrderedDict

K = TypeVar('K')
V = TypeVar('V')


# ============================================================
#                    链地址法哈希表
# ============================================================

class HashTableChaining(Generic[K, V]):
    """
    Hash table using separate chaining.

    Every bucket is the head of a singly linked list of ``Node`` objects, so
    keys whose hashes collide share one list. Once the number of stored pairs
    reaches ``capacity * load_factor`` the bucket array is doubled and every
    node is redistributed.

    ``None`` is a valid value: ``in``, ``[]`` and ``del`` look for the node
    itself instead of treating a ``None`` result from ``get`` as "missing".
    """

    class Node:
        """Singly linked list node holding one key/value pair."""

        def __init__(self, key: K, value: V, next=None):
            self.key = key
            self.value = value
            self.next = next  # following node in the same bucket, or None

    def __init__(self, capacity: int = 16, load_factor: float = 0.75):
        """
        Args:
            capacity: initial number of buckets. Values below 1 are clamped
                to 1, because a zero-size table would divide by zero in
                ``_hash`` and could never grow by doubling.
            load_factor: pairs-per-bucket ratio at which ``put`` doubles the
                bucket array.
        """
        self._capacity = max(1, capacity)
        self._load_factor = load_factor
        self._size = 0
        self._buckets: List[Optional[self.Node]] = [None] * self._capacity

    def _hash(self, key: K) -> int:
        """Map ``key`` to a bucket index in ``[0, capacity)``."""
        return hash(key) % self._capacity

    def _iter_nodes(self) -> Iterator["HashTableChaining.Node"]:
        """Yield every node, bucket by bucket, each chain from head to tail."""
        for head in self._buckets:
            node = head
            while node is not None:
                yield node
                node = node.next

    def _find_node(self, key: K) -> Optional["HashTableChaining.Node"]:
        """Return the node that stores ``key``, or None if the key is absent."""
        node = self._buckets[self._hash(key)]
        while node is not None:
            if node.key == key:
                return node
            node = node.next
        return None

    def _unlink(self, key: K) -> Optional["HashTableChaining.Node"]:
        """Detach and return the node for ``key``; None if the key is absent."""
        index = self._hash(key)
        prev = None
        node = self._buckets[index]
        while node is not None:
            if node.key == key:
                if prev is None:
                    self._buckets[index] = node.next
                else:
                    prev.next = node.next
                self._size -= 1
                return node
            prev = node
            node = node.next
        return None

    def put(self, key: K, value: V) -> None:
        """
        Insert ``key`` or overwrite its value.

        Time complexity: O(1) on average.
        """
        # Grow before inserting so the chains stay short.
        if self._size >= self._capacity * self._load_factor:
            self._resize(self._capacity * 2)

        node = self._find_node(key)
        if node is not None:
            node.value = value  # existing key: update in place
            return

        # New key: push it onto the front of its bucket's chain.
        index = self._hash(key)
        self._buckets[index] = self.Node(key, value, self._buckets[index])
        self._size += 1

    def get(self, key: K) -> Optional[V]:
        """
        Return the value for ``key``, or None if the key is absent.

        A stored None value is indistinguishable from a missing key here;
        use ``key in table`` or ``table[key]`` when that matters.
        Time complexity: O(1) on average.
        """
        node = self._find_node(key)
        return None if node is None else node.value

    def remove(self, key: K) -> Optional[V]:
        """
        Delete ``key``; return its value, or None if the key was absent.

        Time complexity: O(1) on average.
        """
        node = self._unlink(key)
        return None if node is None else node.value

    def contains(self, key: K) -> bool:
        """Return True if ``key`` is stored, even when its value is None."""
        return self._find_node(key) is not None

    def _resize(self, new_capacity: int) -> None:
        """Re-link every existing node into a fresh array of ``new_capacity`` buckets."""
        old_buckets = self._buckets
        self._capacity = new_capacity
        self._buckets = [None] * new_capacity

        # Keys are already unique, so nodes are moved directly instead of
        # being re-inserted through put(); _size is unchanged.
        for head in old_buckets:
            node = head
            while node is not None:
                following = node.next
                index = self._hash(node.key)
                node.next = self._buckets[index]
                self._buckets[index] = node
                node = following

    def __setitem__(self, key: K, value: V) -> None:
        self.put(key, value)

    def __getitem__(self, key: K) -> V:
        node = self._find_node(key)
        if node is None:
            raise KeyError(key)
        return node.value

    def __delitem__(self, key: K) -> None:
        if self._unlink(key) is None:
            raise KeyError(key)

    def __contains__(self, key: K) -> bool:
        return self.contains(key)

    def __len__(self) -> int:
        return self._size

    def keys(self) -> List[K]:
        """Return all keys (bucket order, not insertion order)."""
        return [node.key for node in self._iter_nodes()]

    def values(self) -> List[V]:
        """Return all values, in the same order as ``keys()``."""
        return [node.value for node in self._iter_nodes()]

    def items(self) -> List[Tuple[K, V]]:
        """Return all (key, value) pairs, in the same order as ``keys()``."""
        return [(node.key, node.value) for node in self._iter_nodes()]

    def __repr__(self) -> str:
        items = [f"{k}: {v}" for k, v in self.items()]
        return "{" + ", ".join(items) + "}"


# ============================================================
#                    开放地址法哈希表
# ============================================================

class HashTableOpenAddressing(Generic[K, V]):
    """
    Hash table using open addressing with linear probing.

    Keys and values live in two parallel arrays. Slots that were never used
    hold the private ``_EMPTY`` sentinel, and removed slots hold the
    ``_DELETED`` tombstone. Because neither marker is None, None can be
    stored as an ordinary key or value.

    Tombstones still lengthen probe sequences, so they count toward the load.
    When live entries plus tombstones reach ``capacity * load_factor``, the
    table is rebuilt. It doubles in size only if the live entries alone
    justify it; otherwise it is rehashed at the same size to purge tombstones.
    """

    # Slot that has never held an entry: a probe may stop here.
    _EMPTY = object()
    # Slot whose entry was removed: a probe must continue past it.
    _DELETED = object()

    def __init__(self, capacity: int = 16, load_factor: float = 0.5):
        """
        Args:
            capacity: initial number of slots. Values below 1 are clamped to
                1 so that ``_hash`` never takes a modulus by zero.
            load_factor: fraction of slots (live entries plus tombstones)
                that may be occupied before ``put`` rebuilds the table.
        """
        self._capacity = max(1, capacity)
        self._load_factor = load_factor
        self._size = 0        # number of live key/value pairs
        self._tombstones = 0  # number of slots currently holding _DELETED
        self._keys: List = [self._EMPTY] * self._capacity
        self._values: List = [None] * self._capacity

    def _hash(self, key: K) -> int:
        """Map ``key`` to its home slot in ``[0, capacity)``."""
        return hash(key) % self._capacity

    def _find_slot(self, key: K) -> Tuple[int, bool]:
        """
        Probe linearly, starting at the home slot of ``key``.

        Returns:
            ``(index, True)`` if ``key`` is stored at ``index``. Otherwise
            ``(index, False)``, where ``index`` is the slot an insert should
            use (the first tombstone passed, else the empty slot that ended
            the probe). Returns ``(-1, False)`` if every slot holds a live
            entry other than ``key``.
        """
        index = self._hash(key)
        first_deleted = -1

        for _ in range(self._capacity):
            slot_key = self._keys[index]
            if slot_key is self._EMPTY:
                # End of the probe chain: the key is absent.
                return (index if first_deleted == -1 else first_deleted), False
            if slot_key is self._DELETED:
                # Remember the first reusable slot but keep probing: the key
                # may still live further along the chain.
                if first_deleted == -1:
                    first_deleted = index
            elif slot_key == key:
                return index, True
            index = (index + 1) % self._capacity

        # Wrapped all the way around without meeting the key or an empty slot.
        return first_deleted, False

    def put(self, key: K, value: V) -> None:
        """
        Insert ``key`` or overwrite its value.

        Raises:
            RuntimeError: if no free slot exists. With ``load_factor`` > 1
                the table can fill up before a rebuild is triggered.
        """
        if self._size + self._tombstones >= self._capacity * self._load_factor:
            if self._size >= self._capacity * self._load_factor:
                self._resize(self._capacity * 2)
            else:
                # Only tombstones pushed the load over the limit: rehashing
                # at the same size is enough to clear them.
                self._resize(self._capacity)

        index, found = self._find_slot(key)
        if not found:
            if index == -1:
                raise RuntimeError("哈希表已满")
            if self._keys[index] is self._DELETED:
                self._tombstones -= 1  # reusing a tombstone slot
            self._size += 1

        self._keys[index] = key
        self._values[index] = value

    def get(self, key: K) -> Optional[V]:
        """
        Return the value for ``key``, or None if the key is absent.

        A stored None value is indistinguishable from a missing key here;
        use ``key in table`` or ``table[key]`` when that matters.
        """
        index, found = self._find_slot(key)
        return self._values[index] if found else None

    def remove(self, key: K) -> Optional[V]:
        """Delete ``key``; return its value, or None if the key was absent."""
        index, found = self._find_slot(key)
        if not found:
            return None
        value = self._values[index]
        # Leave a tombstone rather than an empty slot, so that probe chains
        # running through this slot to other keys stay intact.
        self._keys[index] = self._DELETED
        self._values[index] = None
        self._size -= 1
        self._tombstones += 1
        return value

    def _resize(self, new_capacity: int) -> None:
        """Rehash every live entry into ``new_capacity`` fresh slots, dropping tombstones."""
        old_keys, old_values = self._keys, self._values

        self._capacity = new_capacity
        self._keys = [self._EMPTY] * new_capacity
        self._values = [None] * new_capacity
        self._size = 0
        self._tombstones = 0

        for key, value in zip(old_keys, old_values):
            if key is self._EMPTY or key is self._DELETED:
                continue
            # Keys are unique and the new arrays have no tombstones, so each
            # probe ends on an empty slot.
            index, _ = self._find_slot(key)
            self._keys[index] = key
            self._values[index] = value
            self._size += 1

    def __setitem__(self, key: K, value: V) -> None:
        self.put(key, value)

    def __getitem__(self, key: K) -> V:
        index, found = self._find_slot(key)
        if not found:
            raise KeyError(key)
        return self._values[index]

    def __delitem__(self, key: K) -> None:
        if key not in self:
            raise KeyError(key)
        self.remove(key)

    def __contains__(self, key: K) -> bool:
        _, found = self._find_slot(key)
        return found

    def __len__(self) -> int:
        return self._size


# ============================================================
#                    LRU 缓存
# ============================================================

class LRUCache:
    """
    Least-recently-used cache with O(1) ``get`` and ``put``.

    A dict maps each key to its list node. A doubly linked list framed by two
    sentinel nodes keeps the entries ordered from most recently used (right
    after ``head``) to least recently used (right before ``tail``).
    """

    class Node:
        """Doubly linked list node holding one cache entry."""

        def __init__(self, key: int = 0, value: int = 0):
            self.key, self.value = key, value
            self.prev = self.next = None

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = {}  # key -> Node

        # Sentinels mean splicing never has to special-case an empty list.
        self.head, self.tail = self.Node(), self.Node()
        self.head.next, self.tail.prev = self.tail, self.head

    def _add_to_head(self, node: Node) -> None:
        """Splice ``node`` in directly after the head sentinel (most recent)."""
        first = self.head.next
        node.prev, node.next = self.head, first
        first.prev = node
        self.head.next = node

    def _remove_node(self, node: Node) -> None:
        """Unlink ``node`` by joining its neighbours to each other."""
        before, after = node.prev, node.next
        before.next = after
        after.prev = before

    def _move_to_head(self, node: Node) -> None:
        """Mark ``node`` as the most recently used entry."""
        self._remove_node(node)
        self._add_to_head(node)

    def _remove_tail(self) -> Node:
        """Unlink and return the least recently used node."""
        victim = self.tail.prev
        self._remove_node(victim)
        return victim

    def get(self, key: int) -> int:
        """Return the value for ``key`` and mark it most recent, or -1 if absent."""
        node = self.cache.get(key)
        if node is None:
            return -1
        self._move_to_head(node)
        return node.value

    def put(self, key: int, value: int) -> None:
        """Insert or update ``key``; evict the least recently used entry when over capacity."""
        node = self.cache.get(key)
        if node is not None:
            node.value = value
            self._move_to_head(node)
            return

        node = self.Node(key, value)
        self.cache[key] = node
        self._add_to_head(node)

        if len(self.cache) > self.capacity:
            del self.cache[self._remove_tail().key]


class LRUCacheSimple:
    """
    LRU cache built on ``OrderedDict``.

    Items are kept in recency order: the first item is the least recently
    used and the last item is the most recently used.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key: int) -> int:
        """Return the cached value and mark it most recent, or -1 if absent."""
        try:
            self.cache.move_to_end(key)
        except KeyError:
            return -1
        return self.cache[key]

    def put(self, key: int, value: int) -> None:
        """Insert or overwrite ``key``; drop the oldest item when over capacity."""
        self.cache[key] = value
        # Promotes an existing key; a no-op for a key that was just appended.
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)


# ============================================================
#                    常见哈希表算法
# ============================================================

def two_sum(nums: List[int], target: int) -> List[int]:
    """
    Two Sum: return the indices of the first pair that adds up to ``target``.

    Walks the list once, remembering where each value was last seen. Returns
    ``[earlier_index, current_index]``, or ``[]`` if no pair exists.

    Time complexity: O(n)
    Space complexity: O(n)
    """
    index_of = {}  # value -> index where it was last seen

    for position, value in enumerate(nums):
        earlier = index_of.get(target - value)
        if earlier is not None:
            return [earlier, position]
        index_of[value] = position

    return []


def group_anagrams(strs: List[str]) -> List[List[str]]:
    """
    Group words that are anagrams of each other.

    Two words fall in the same group exactly when their sorted characters
    match. Groups appear in order of first occurrence, and words keep their
    input order within each group.

    Input:  ["eat","tea","tan","ate","nat","bat"]
    Output: [["eat","tea","ate"],["tan","nat"],["bat"]]

    Time complexity: O(n * k log k), where k is the average word length
    Space complexity: O(n * k)
    """
    buckets = {}  # sorted-letters signature -> words sharing it
    for word in strs:
        signature = "".join(sorted(word))
        buckets.setdefault(signature, []).append(word)
    return list(buckets.values())


def longest_consecutive(nums: List[int]) -> int:
    """
    Length of the longest run of consecutive integers in ``nums``.

    Input:  [100, 4, 200, 1, 3, 2]
    Output: 4 (the run [1, 2, 3, 4])

    Runs are only counted from their smallest member, so each number is
    visited by at most one inner walk.

    Time complexity: O(n)
    Space complexity: O(n)
    """
    if not nums:
        return 0

    values = set(nums)
    best = 0

    for start in values:
        if start - 1 in values:
            continue  # not the beginning of a run

        current, length = start, 1
        while current + 1 in values:
            current += 1
            length += 1

        if length > best:
            best = length

    return best


def contains_duplicate(nums: List[int]) -> bool:
    """
    Return True if any value occurs more than once in ``nums``.

    Deduplicating through a set can only shrink the collection, so a smaller
    set means at least one value repeated.

    Time complexity: O(n)
    Space complexity: O(n)
    """
    distinct = set(nums)
    return len(distinct) < len(nums)


def is_anagram(s: str, t: str) -> bool:
    """
    Return True if ``t`` uses exactly the same characters as ``s``.

    Counts every character of ``s``, then consumes those counts while walking
    ``t``. ``t`` fails as soon as it needs a character that ``s`` has run
    out of.

    Time complexity: O(n)
    Space complexity: O(k) for k distinct characters. This is constant
    only if inputs are limited to a fixed alphabet such as the 26 lowercase
    letters.
    """
    if len(s) != len(t):
        return False

    remaining = {}  # character -> occurrences in s not yet matched by t
    for ch in s:
        remaining[ch] = remaining.get(ch, 0) + 1

    for ch in t:
        left = remaining.get(ch, 0) - 1
        if left < 0:
            return False
        remaining[ch] = left

    return True


def find_duplicate(nums: List[int]) -> int:
    """
    Find the repeated number using Floyd's cycle detection (no extra space).

    Expects n + 1 integers, each in [1, n]. Treating each value as a pointer
    to the next index turns the array into a linked structure whose cycle
    starts at the duplicated value.

    Time complexity: O(n)
    Space complexity: O(1)
    """
    # Phase 1: the tortoise takes one step, the hare two, until they meet
    # inside the cycle. The first step of each is taken up front.
    tortoise = nums[nums[0]]
    hare = nums[nums[nums[0]]]
    while tortoise != hare:
        tortoise = nums[tortoise]
        hare = nums[nums[hare]]

    # Phase 2: restart the tortoise from the beginning; moving both one step
    # at a time, they meet at the cycle entrance, which is the duplicate.
    tortoise = nums[0]
    while tortoise != hare:
        tortoise = nums[tortoise]
        hare = nums[hare]

    return tortoise


def subarray_sum(nums: List[int], k: int) -> int:
    """
    Count the contiguous subarrays whose elements sum to ``k``.

    A subarray ending at the current position sums to ``k`` exactly when an
    earlier prefix sum equals ``running - k``. Tallying how often each prefix
    sum has occurred therefore gives the count in one pass.

    Time complexity: O(n)
    Space complexity: O(n)
    """
    running = 0
    seen = {0: 1}  # prefix sum -> number of prefixes with that sum
    total = 0

    for value in nums:
        running += value
        total += seen.get(running - k, 0)
        seen[running] = seen.get(running, 0) + 1

    return total


# ============================================================
#                    测试代码
# ============================================================

if __name__ == "__main__":
    RULE = "=" * 60

    def _banner(title: str, leading_newline: bool = False) -> None:
        """Print ``title`` between two rules, optionally preceded by a blank line."""
        print(("\n" if leading_newline else "") + RULE)
        print(title)
        print(RULE)

    # Separate-chaining hash table demo
    _banner("链地址法哈希表测试")

    table = HashTableChaining()

    print("\n--- 插入操作 ---")
    for fruit, name_zh in (("apple", "苹果"), ("banana", "香蕉"),
                           ("cherry", "樱桃"), ("date", "枣")):
        table[fruit] = name_zh
    print(f"哈希表: {table}")

    print("\n--- 查找操作 ---")
    print(f"apple: {table['apple']}")
    print(f"banana: {table.get('banana')}")
    print(f"grape exists: {'grape' in table}")

    print("\n--- 删除操作 ---")
    del table["date"]
    print(f"删除 date 后: {table}")

    print("\n--- 遍历操作 ---")
    print(f"keys: {table.keys()}")
    print(f"values: {table.values()}")

    # LRU cache demo
    _banner("LRU 缓存测试", leading_newline=True)

    cache = LRUCache(3)

    print("\n--- 操作序列 ---")
    steps = [
        ("put", 1, 1),
        ("put", 2, 2),
        ("put", 3, 3),
        ("get", 1, None),
        ("put", 4, 4),  # evicts key 2, the least recently used
        ("get", 2, None),
        ("get", 3, None),
    ]
    for action, key, value in steps:
        if action == "put":
            cache.put(key, value)
            print(f"put({key}, {value})")
        else:
            print(f"get({key}) = {cache.get(key)}")

    # Algorithm demos
    _banner("算法测试", leading_newline=True)

    print("\n--- 两数之和 ---")
    pair_nums, pair_target = [2, 7, 11, 15], 9
    print(f"nums: {pair_nums}, target: {pair_target}")
    print(f"结果: {two_sum(pair_nums, pair_target)}")

    print("\n--- 字母异位词分组 ---")
    words = ["eat", "tea", "tan", "ate", "nat", "bat"]
    print(f"输入: {words}")
    print(f"分组: {group_anagrams(words)}")

    print("\n--- 最长连续序列 ---")
    seq = [100, 4, 200, 1, 3, 2]
    print(f"输入: {seq}")
    print(f"最长连续序列长度: {longest_consecutive(seq)}")

    print("\n--- 和为K的子数组 ---")
    sub_nums, sub_k = [1, 1, 1], 2
    print(f"nums: {sub_nums}, k: {sub_k}")
    print(f"子数组数量: {subarray_sum(sub_nums, sub_k)}")

💬 讨论

使用 GitHub 账号登录后即可参与讨论

基于 MIT 许可发布