Prefix caching
llm.inference(
    input_tokens: list[int],          # N tokens
    previous_kv_cache: list[Tensor],  # KV cache for N tokens or fewer
) -> output_tokens, new_kv_cache
# output_tokens: N' new tokens
# new_kv_cache:  KV cache of N + N' tokens
Key: tokens. Value: KV cache tensors.
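class KVCacheStore:
    def store(self, tokens, kv_cache_tensors):
        """Save the KV cache tensors under the given tokens."""
        pass

    def retrieve(self, tokens):
        """Return the KV cache tensors stored for the given tokens."""
        pass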
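A minimal sketch of this interface backed by a plain dictionary, assuming exact-match lookup only. The class name `InMemoryKVCacheStore` is hypothetical, and strings such as `"KV1"` stand in for real tensors.

```python
from typing import Any, Optional


class InMemoryKVCacheStore:
    """Exact-match store: token sequence -> per-token KV cache entries."""

    def __init__(self) -> None:
        self._entries = {}

    def store(self, tokens: list[int], kv_cache_tensors: list[Any]) -> None:
        # Lists are unhashable, so the key is a tuple of token ids.
        self._entries[tuple(tokens)] = list(kv_cache_tensors)

    def retrieve(self, tokens: list[int]) -> Optional[list[Any]]:
        # Returns None when this exact token sequence was never stored.
        return self._entries.get(tuple(tokens))


store = InMemoryKVCacheStore()
store.store([1, 2, 3], ["KV1", "KV2", "KV3"])
hit = store.retrieve([1, 2, 3])
miss = store.retrieve([1, 2])  # a prefix of a stored key is still a miss here
```

An exact-match store misses on `[1, 2]` even though its KV cache is contained in the stored entry, which is the gap that prefix-based matching addresses.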
Prefix-based matching
- Tokens 1: ABCDE -> [KV1, KV2, KV3, KV4, KV5]
- Tokens 2: ABCDF -> [KV1, KV2, KV3, KV4, KV6]
kv_cache_store.store("ABCDE", [KV1, KV2, KV3, KV4, KV5])
kv_cache_store.retrieve("ABCD")  # -> [KV1, KV2, KV3, KV4]
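One way to support this kind of lookup is a trie keyed by tokens, where each node holds the KV entry for the token that ends there. The sketch below is an illustration under that assumption, not a description of any particular system. `TrieNode` and `PrefixKVCacheStore` are hypothetical names, each character stands in for one token, and strings stand in for tensors.

```python
from typing import Any, Optional


class TrieNode:
    def __init__(self) -> None:
        self.children: dict[str, "TrieNode"] = {}
        self.kv: Optional[Any] = None  # KV entry for the token ending at this node


class PrefixKVCacheStore:
    def __init__(self) -> None:
        self.root = TrieNode()

    def store(self, tokens: str, kv_cache: list[Any]) -> None:
        # Walk (and create) one node per token, attaching its KV entry.
        node = self.root
        for token, kv in zip(tokens, kv_cache):
            node = node.children.setdefault(token, TrieNode())
            node.kv = kv

    def retrieve(self, tokens: str) -> list[Any]:
        """Return KV entries for the longest cached prefix of tokens."""
        node, found = self.root, []
        for token in tokens:
            node = node.children.get(token)
            if node is None:
                break  # the cached prefix ends here
            found.append(node.kv)
        return found


store = PrefixKVCacheStore()
store.store("ABCDE", ["KV1", "KV2", "KV3", "KV4", "KV5"])
exact_prefix = store.retrieve("ABCD")
shared_prefix = store.retrieve("ABCDF")  # only "ABCD" is shared with "ABCDE"
no_match = store.retrieve("XYZ")
```

With this layout, `ABCDF` reuses the four entries it shares with `ABCDE`, and only the KV cache for `F` still has to be computed.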
- “Trie”
- “ABCDEF” -> “AB”, “CD”, “EF” -> list of chunked prefix hashes
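import hashlib
prefix_hash = ""
chunk_hashes = []
for chunk in chunked_tokens:  # ["AB", "CD", "EF"]
    # sha256 keeps the hash a string and stable across processes (built-in hash() does neither).
    chunk_hash = hashlib.sha256((prefix_hash + chunk).encode()).hexdigest()
    chunk_hashes.append(chunk_hash)
    prefix_hash = chunk_hash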
# Given chunked prefix hashes and the matching chunked KV cache
# store
for chunk_hash, chunk_kv in zip(...):
    redis.set(chunk_hash, chunk_kv)
# retrieve
for chunk_hash in ...:
    kv_chunk = redis.get(chunk_hash)
    if kv_chunk is None:
        break  # first miss: the cached prefix ends here
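The chunking, hashing, store, and retrieve steps can be put together in one runnable sketch. Several details here are assumptions for illustration: a plain dict named `backend` stands in for Redis, SHA-256 is used as the hash, the chunk size is 2, a trailing partial chunk is simply not cached, and the helper names are hypothetical.

```python
import hashlib


def chunk_tokens(tokens: str, chunk_size: int) -> list[str]:
    # Assumption: only full chunks are cached; a trailing partial chunk is dropped.
    full = len(tokens) - len(tokens) % chunk_size
    return [tokens[i:i + chunk_size] for i in range(0, full, chunk_size)]


def prefix_hashes(chunks: list[str]) -> list[str]:
    # Each hash chains the previous one, so it identifies the whole prefix.
    hashes, prefix_hash = [], ""
    for chunk in chunks:
        prefix_hash = hashlib.sha256((prefix_hash + chunk).encode()).hexdigest()
        hashes.append(prefix_hash)
    return hashes


backend: dict[str, str] = {}  # stand-in for a key-value store such as Redis


def store(tokens: str, kv_chunks: list[str], chunk_size: int = 2) -> None:
    hashes = prefix_hashes(chunk_tokens(tokens, chunk_size))
    for chunk_hash, chunk_kv in zip(hashes, kv_chunks):
        backend[chunk_hash] = chunk_kv


def retrieve(tokens: str, chunk_size: int = 2) -> list[str]:
    found = []
    for chunk_hash in prefix_hashes(chunk_tokens(tokens, chunk_size)):
        kv_chunk = backend.get(chunk_hash)
        if kv_chunk is None:
            break  # stop at the first miss; later chunks cannot match
        found.append(kv_chunk)
    return found


store("ABCDEF", ["KV_AB", "KV_CD", "KV_EF"])
shared = retrieve("ABCDXY")     # shares the chunks "AB" and "CD"
different = retrieve("ZZCDEF")  # "CD" and "EF" match as text, but not as prefixes
```

Because every hash includes the hash before it, `ZZCDEF` gets no hits: its `CD` chunk hashes differently from the `CD` that follows `AB`.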
Eviction
- LRU, LFU…
- “ABCDEF” -> [“AB”, KV1], [“CD”, KV2], [“EF”, KV3]
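A minimal sketch of LRU eviction over per-chunk entries, assuming a fixed capacity counted in chunks. `LRUChunkStore` is a hypothetical name, the keys are the chunk labels from the example above (in practice they would be chunk hashes), and strings stand in for tensors.

```python
from collections import OrderedDict
from typing import Any, Optional


class LRUChunkStore:
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._entries = OrderedDict()  # oldest entry first

    def put(self, chunk_key: str, chunk_kv: Any) -> None:
        self._entries[chunk_key] = chunk_kv
        self._entries.move_to_end(chunk_key)  # mark as most recently used
        while len(self._entries) > self.capacity:
            self._entries.popitem(last=False)  # evict the least recently used

    def get(self, chunk_key: str) -> Optional[Any]:
        if chunk_key not in self._entries:
            return None
        self._entries.move_to_end(chunk_key)  # a read also counts as a use
        return self._entries[chunk_key]

    def keys(self) -> list[str]:
        return list(self._entries)


store = LRUChunkStore(capacity=2)
store.put("AB", "KV1")
store.put("CD", "KV2")
store.get("AB")         # "AB" is now more recent than "CD"
store.put("EF", "KV3")  # over capacity: "CD" is evicted
remaining = store.keys()
```

Here `EF` survives while `CD` is gone. With chained prefix hashes, a retrieval loop that stops at the first miss would no longer reach `EF`, so an eviction policy may want to drop later chunks before earlier ones.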