kizumi_header_banner_img

夜晚 未来 永远 恐惧 梦见 终结

加载中

文章导读

LMCache源码解析0001_cache_engine.py


avatar
Gouki 2026年1月24日 43

LMCache—V0.3.7部分源码解析

背景介绍

日常做大模型推理(vLLM)的存储侧优化,测数据 🤪 人已测傻,治好了也流口水了。用 vllm + lmcache 跑多轮对话 benchmark 时:单机多卡 hit 为 0,load token 为负数;单机单卡 hit 正常命中,load 也正常。下面通过阅读源码来定位问题。

代码阅读

核心类:lmcache/cache_engine.py 下的LMCacheEngine类

初始化

https://gitcode.com/GitHub_Trending/lm/LMCache/blob/v0.3.7/lmcache/cache_engine.py

LMCacheEngine类

1.配置信息和元信息(看多了其实就是自己的配置加模型的配置加分布式环境的配置)

2.存储后端

3.监视的指标

def __init__(
    self,
    config: LMCacheEngineConfig,
    metadata: LMCacheEngineMetadata,
):
    """
    Set up the cache engine: configuration, storage backend and stats.

    :param config: engine configuration; ``chunk_size`` and
        ``save_decode_cache`` are copied onto the instance.
    :param metadata: engine metadata, forwarded to the storage backend
        and to the usage context.

    raises: RuntimeError if the loaded configuration does not
        match the current configuration
    """
    self.config = config
    self.metadata = metadata
    self.chunk_size = config.chunk_size
    self.save_decode_cache = config.save_decode_cache

    # Running hit/miss token counters and the derived hit rate.
    self.miss_tokens_count = 0
    self.hit_tokens_count = 0
    self.hit_rate = 0.0

    # The backend is created for "cuda" when CUDA is available, else "cpu".
    self.engine_ = CreateStorageBackend(
        config, metadata, "cuda" if torch.cuda.is_available() else "cpu"
    )
    logger.debug(f"Current storage backend type {type(self.engine_)}")

    InitializeUsageContext(config, metadata)
    self.stats_monitor = LMCStatsMonitor.GetOrCreate()

缓存设计-拿什么算 怎么算

主要看它是怎么做 hash 的:用的是 Python 原生的 hash()。设计上是多层级(链式)hash——每个 chunk 的 hash 都由"前缀 hash + 本块 token"计算得到,只有前面的 chunk 命中,才会继续匹配后面的 chunk,类似区块链的链式校验。

def _make_key(self, chunk_hash: int, fmt: str) -> CacheEngineKey:
    """Build the storage key for one chunk.

    The key is made of, in order: the format tag, the model name, the
    world size and the worker id (all three read from ``self.metadata``),
    and finally the chunk hash.
    """
    meta = self.metadata
    key_fields = (
        fmt,  # format identifier
        meta.model_name,  # model name
        meta.world_size,
        meta.worker_id,
        chunk_hash,
    )
    return CacheEngineKey(*key_fields)
def _hash(
    self,
    tokens: torch.Tensor,
    prefix_hash: int,
) -> int:
    """Hash one token chunk chained onto the hash of everything before it.

    :param tokens: the chunk's tokens; converted to a tuple via ``tolist()``
    :param prefix_hash: hash of the preceding prefix
    :return: ``hash()`` of the pair ``(prefix_hash, token tuple)``
    """
    # NOTE(review): Python salts hash() of str/bytes per process unless
    # PYTHONHASHSEED is fixed. If the initial prefix hash is not an int,
    # the same tokens may hash differently in different processes -
    # confirm what the seed value is.
    token_ids = tuple(tokens.tolist())
    chained = (prefix_hash, token_ids)
    return hash(chained)
def _prefix_hash(
    self,
    token_chunks: Iterable[torch.Tensor],
    num_skip_chunk: Optional[int] = 0,
) -> List[int]:
    """Compute the chained hash of every chunk, then drop the leading ones.

    Each chunk's hash is derived from the previous chunk's hash, starting
    from ``self._get_init_hash()``, so every chunk is hashed even when its
    own hash is skipped in the result.

    :param token_chunks: the token chunks, in sequence order
    :param num_skip_chunk: number of leading hashes to leave out
    :return: the chained hashes from index ``num_skip_chunk`` onwards
    """

    def _chained():
        running = self._get_init_hash()
        for chunk in token_chunks:
            running = self._hash(chunk, running)
            yield running

    return list(_chained())[num_skip_chunk:]

KV缓存的分块与序列化

分块缓存:面对变长序列,用固定大小(chunk_size)的块来分块存储——以有限面对无限,以定长面对变长。

先切

def _chunk_tokens(
    self,
    tokens: torch.Tensor,
) -> Iterable[torch.Tensor]:
    """
    Chunk the tokens into chunks of size self.chunk_size.

    :param tokens: the input tokens, with shape [seq_len]

    :return: a generator of chunks of tokens, each with
        shape [chunk_size]; the last chunk is shorter when seq_len
        is not a multiple of self.chunk_size
    """
    # TODO(Jiayi): the following step can be parallelized
    # Move to CPU once, so every yielded slice is a CPU tensor.
    tokens = tokens.cpu()
    for i in range(0, len(tokens), self.chunk_size):
        yield tokens[i : i + self.chunk_size]

整合:这对双向转换方法的目的,是把 tuple 结构的 kv 张量转换成单个大张量(blob),以及把单个大张量再转换回 tuple 结构的 kv 张量——大口吞,大口吐。

def _tuple_kv_to_blob(
    self,
    kv_tensors: KVCache,
) -> torch.Tensor:
    """Convert the nested tuple of kv tensors to a single
    big tensor with 2 extra dimensions.

    :param kv_tensors: an iterable of per-layer pairs, where index 0 is
        the key tensor and index 1 is the value tensor
    :return: one tensor whose first axis indexes the layer and whose
        second axis indexes key (0) / value (1)
    """
    k_temp = []
    v_temp = []
    for kv_layer in kv_tensors:
        k_temp.append(kv_layer[0])
        v_temp.append(kv_layer[1])

    k_tensor_blob = torch.stack(k_temp)
    v_tensor_blob = torch.stack(v_temp)

    # Stacking K and V gives [2, num_layer, ...]; swapping the first two
    # axes yields
    # kv_tensors: [num_layer, 2, num_tok, num_kv_head, head_size]
    kv_tensors_flatten = torch.stack((k_tensor_blob, v_tensor_blob))
    kv_tensors_flatten = kv_tensors_flatten.permute([1, 0, 2, 3, 4])

    return kv_tensors_flatten


def _blob_to_tuple_kv(
    self,
    blob: torch.Tensor,
) -> KVCache:
    """
    Convert a single big tensor to the nested tuple of kv tensors.

    :param blob: tensor whose first axis is split into one entry per
        layer; index 0 and 1 of each entry become the (key, value) pair
    :return: a tuple of ``(key, value)`` pairs, one per slice along dim 0
    """
    outer_unbound = torch.unbind(blob, dim=0)
    return tuple(
        (inner_tensor[0], inner_tensor[1]) for inner_tensor in outer_unbound
    )

存储策略与批量操作优化

store方法
1.掩码 kv_tensors_mask

2.判断重复 _make_chunks_skip_existing

3.批量操作 engine_.batched_put

@_lmcache_nvtx_annotate
@torch.inference_mode()
def store(
    self,
    tokens: torch.Tensor,
    kv_tensors_raw: KVCache,
    kv_tensors_mask: Optional[torch.Tensor] = None,
    skip_existing=True,
    blocking=True,
) -> None:
    """
    Store the KV cache of the tokens into the cache engine.

    Format: either 'huggingface' or 'vllm'
            For huggingface,
            it should have the shape of
            [num_heads, num_tokens, head_size]

            For vllm,
            it should have the shape of
            [num_tokens, num_heads, head_size]

    :param tokens: the input tokens, with shape [seq_len]
    :param kv_tensors_raw: the kv cache of the tokens, in
        the format of nested tuples. The number of tokens
        in the kv_tensors_raw should be the same as trues in
        kv_tensors_mask if mask is not None. Otherwise,
        it should be the same as the input tokens.
    :param kv_tensors_mask: a boolean mask of tokens indicating
        which tokens' KV Cache should be stored. Only support
        suffix mask. None is taken as trues for all tokens.
        len(kv_tensors_mask) should be the same as len(tokens)
        number of true should be the same as kv_tensors_raw token
        number.
    :param skip_existing: whether to skip the existing chunks
    :param blocking: whether to wait for the store operation to finish

    :return: None

    :raises AssertionError: if tokens or mask are not 1-D, their lengths
        differ, the masked-out prefix is not a whole number of chunks,
        a partial mask is used with skip_existing=False, or the KV token
        count does not match the unmasked tokens

    Note:
        The KV cache should NOT have the "batch" dimension.
    """
    start_time = time.perf_counter()

    monitor_req_id = self.stats_monitor.on_store_request(
        self._num_tokens_in_kv(kv_tensors_raw, self.metadata.fmt)
    )

    fmt = self.metadata.fmt
    if kv_tensors_mask is None:
        kv_tensors_mask = torch.ones_like(tokens, dtype=torch.bool)
    assert len(tokens.shape) == 1, f"Invalid shape of tokens: {tokens.shape}"
    assert len(kv_tensors_mask.shape) == 1, (
        f"Invalid shape of mask: {kv_tensors_mask.shape}"
    )
    assert len(tokens) == len(kv_tensors_mask), (
        "token length does not match mask length"
    )

    # NOTE(Sixian): Now kv_tensors_mask always a suffix mask.
    # The number of False entries is therefore the length of the
    # leading prefix that is not stored.
    num_skip_tok = len(kv_tensors_mask) - torch.sum(kv_tensors_mask)
    assert num_skip_tok == 0 or skip_existing, (
        "When skip_existing is False, the mask must cover all tokens"
    )

    num_skip_chunk = num_skip_tok // self.chunk_size
    assert num_skip_tok == num_skip_chunk * self.chunk_size, (
        "Store KV mask should align to chunk size"
    )
    assert (
        len(tokens) == self._num_tokens_in_kv(kv_tensors_raw, fmt) + num_skip_tok
    ), "Number of tokens in the kv cache does not match the input tokens"

    kv_tensors = self._tuple_kv_to_blob(kv_tensors_raw)

    # chunk the tokens and the kv caches
    chunk_hashes_and_kvs = self._make_chunks(
        tokens, kv_tensors, fmt, num_skip_chunk, skip_existing=skip_existing
    )
    if not blocking:
        # Materialize the chunks before handing them to a non-blocking
        # put. NOTE(review): presumably _make_chunks is lazy and must not
        # be consumed after this call returns - confirm.
        chunk_hashes_and_kvs = list(chunk_hashes_and_kvs)
    end_make_chunks = time.perf_counter()

    # store them into the dictionary
    n_chunks = self.engine_.batched_put(
        (
            (self._make_key(chunk_hash, fmt), kv_chunk)
            for chunk_hash, kv_chunk in chunk_hashes_and_kvs
        ),
        blocking=blocking,
    )

    end_time = time.perf_counter()
    logger.info(
        f"Stored/updated {n_chunks} chunks, total time "
        f"{end_time - start_time:.2f}s, make chunks time "
        f"{end_make_chunks - start_time:.2f}s"
    )
    self.stats_monitor.on_store_finished(monitor_req_id)

检索机制与前缀匹配优化

1.分块哈希计算

2.批量查询

3.结果拼接

4.命中统计

# prefix caching only needs a mask_len
    # but non-prefix might need an roi
    @_lmcache_nvtx_annotate
    @torch.inference_mode()
    def retrieve(
        self,
        tokens: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_tuple: bool = True,
    ) -> Tuple[Union[KVCache, torch.Tensor], torch.Tensor]:
        “””
        Retrieve the KV cache of the tokens from the cache engine. The
        retrieved KV cache should be a prefix of the input tokens.
        The KV cache of the tokens, in the format of nested
        tuples or a single tensor with shape [num_layers, 2, hidden_dim,
        num_tokens] (huggingface) or [num_layers, 2, num_tokens,
        hidden_dim] (vllm).
        Will be an empty tuple if no kv cache is retrieved (no matter
        return_tuple is True or not).
        :param tokens: the input tokens, with shape [seq_len]
        :param mask: a boolean mask of tokens indicating which tokens’
            KV Cache should be retrieved. Currently, only support
            suffix mask.
        :param return_tuple: whether to return the kv cache as a tuple or a
            single tensor
        :return: Tuple[ kv_tensors , ret_mask] indicate which tokens
            are retrieved
        “””
        num_skip_chunk = 0
        num_skip_tok = 0
        ret_mask = torch.ones_like(tokens, dtype=torch.bool)
        if mask is not None:
            num_skip_tok = len(mask) – torch.sum(mask)
            num_skip_chunk = num_skip_tok // self.chunk_size
        ret_mask[:num_skip_tok] = False
        monitor_req_id = self.stats_monitor.on_retrieve_request(
            len(tokens) – num_skip_tok
        )
        st = time.perf_counter()
        fmt = self.metadata.fmt
        chunk_hashes = self._prefix_hash(self._chunk_tokens(tokens), num_skip_chunk)
        retrival_iterator = self.engine_.batched_get(
            (self._make_key(chunk_hash, fmt) for chunk_hash in chunk_hashes),
        )
        retrieved_kv_chunks = []
        for chunk in retrival_iterator:
            if chunk is None:
                break
            retrieved_kv_chunks.append(chunk)
        “”” concatenate the kv cache “””
        dim = None
        match fmt:
            case “huggingface”:
                dim = 1
            case “vllm”:
                dim = 0
            case _:
                raise ValueError(f”Invalid format: {fmt}”)
        if len(retrieved_kv_chunks) == 0:
            logging.info(“Retrieved 0 chunks”)
            self.miss_tokens_count += tokens.shape[0]
            ret_mask[:] = False
            self.stats_monitor.on_retrieve_finished(monitor_req_id, 0)
            return (), ret_mask
        # drop extra tokens in the first chunk
        extra_token_len = num_skip_tok – num_skip_chunk * self.chunk_size
        retrieved_kv_chunks[0] = self._slice_kv_at(
            extra_token_len, retrieved_kv_chunks[0], fmt
        )[0]
        ret: Union[KVCache, torch.Tensor]
        if return_tuple:
            st2 = time.perf_counter()
            ret = self._blob_to_tuple_kv(torch.cat(retrieved_kv_chunks, dim=dim + 2))
            ed2 = time.perf_counter()
            logger.info(
                f”Concatenated {len(retrieved_kv_chunks)} chunks “
                f”– elapsed time {ed2 – st2}”
            )
            retrieved_token_count = 0 if len(ret) == 0 else ret[0][0].shape[dim]
        else:
            ret = torch.cat(retrieved_kv_chunks, dim=dim + 2)
            retrieved_token_count = 0 if ret.numel() == 0 else ret.shape[dim + 2]
        ed = time.perf_counter()
        self.hit_tokens_count += retrieved_token_count
        self.miss_tokens_count += len(tokens) – num_skip_tok – retrieved_token_count
        self.hit_rate = self.hit_tokens_count / (
            self.miss_tokens_count + self.hit_tokens_count
        )
        logger.info(
            f”Retrieved {len(retrieved_kv_chunks)} chunks “
            f”({retrieved_token_count} tokens in total) –“
            f”hit rate {self.hit_rate:.2%} — “
            f”elapsed time {ed – st}”
        )
        ret_mask[num_skip_tok + retrieved_token_count :] = False
        self.stats_monitor.on_retrieve_finished(monitor_req_id, retrieved_token_count)
        return ret, ret_mask

 

日志定位

vllm v1 适配 lmcache

lmcache/integration/vllm/vllm_v1_adapter.py

查看存

def wait_for_save(self):
    # Excerpt only: the rest of the method body is omitted. Shown here is
    # the log statement that reports how many tokens are being stored.
    logger.info(
        "Storing KV cache for %d out of %d tokens "
        "(skip_leading_tokens=%d) for request %s",
        len(token_ids) - skip_leading_tokens,
        len(token_ids),
        skip_leading_tokens,
        request.req_id,
    )

查看命中

@_lmcache_nvtx_annotate
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> Optional[int]:
    # Excerpt only: the rest of the method body is omitted. Shown here is
    # the log statement that reports the cache-hit token count.
    logger.info(
        "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, need to load: %d",
        request.request_id,
        request.num_tokens,
        num_external_hit_tokens,
        need_to_allocate,
    )

 

 

 

 



评论(0)

查看评论列表

暂无评论


发表评论

个人信息

avatar

面向kpi编程,python的狗。

16
文章
0
评论
1
用户

分类

最新评论

    广告 10-11