数据结构与算法之美三之手写LruCache

数据结构与算法之美三之手写LruCache

Scroll Down

专栏第三篇,主要通过三种方式实现LRU 缓存淘汰算法

  1. 数组实现
  2. 单链表实现
  3. 单链表+散列表实现

Talk is cheap. Show me the code.
– Linus Torvalds

数组实现

package com.shockang.study.algorithm.archive.lru

import java.util.StringJoiner

import scala.util.control.Breaks._

/**
 * 数组实现的 LruCache
 *
 * @author Shockang
 */
class ArrayLruCache[K, V](_maxSize: Int) {
  //底层用数组存储数据
  private val array: Array[(K, V)] = new Array(_maxSize)

  //已缓存大小
  private var size: Int = 0

  //最大缓存值
  private val maxSize: Int = _maxSize

  //遍历数组根据 key 查找 value,不存在则返回 null
  def get(k: K): V = {
    if (k == null) throw new IllegalArgumentException("key == null")
    //线程安全,加锁
    this.synchronized {
      for (a <- array if (k.equals(a._1))) return a._2
      null.asInstanceOf[V]
    }
  }

  //插入一个键值对,返回旧值或者 null
  def put(k: K, v: V): V = {
    if (k == null || v == null) throw new IllegalArgumentException("key == null || value == null")
    //线程安全,加锁
    this.synchronized {
      //旧值
      var oldValue: V = null.asInstanceOf[V]
      //判断缓存实际大小是否增加
      var flag: Boolean = true
      var i: Int = 0
      breakable(
        while (i < size) {
          if (k.equals(array(i)._1)) {
            oldValue = array(i)._2
            //替换,缓存大小不增加
            flag = false
            break
          }
          i += 1
        }
      )
      //超出最大容量,则将尾部删除
      if (i >= maxSize) {
        //尾部删除了,缓存大小不增加
        flag = false
        i -= 1
      }
      //i 前面的所有元素的全部往后挪一位,这里要逆序操作
      for (j <- i until 0 by -1) array(j) = array(j - 1)
      //插入头部
      array(0) = (k, v)
      if (flag) size += 1
      oldValue
    }
  }

  //根据 key 删除缓存中的k v 对,返回key 对应的 value,找不到返回 null
  def remove(k: K): V = {
    if (k == null) throw new IllegalArgumentException("key == null")
    this.synchronized {
      //是否存在对应的 KV 对
      var flag: Boolean = false
      var i: Int = 0
      var value: V = null.asInstanceOf[V]
      breakable(
        while (i < size) {
          if (k.equals(array(i)._1)) {
            value = array(i)._2
            flag = true
            break
          }
          i += 1
        }
      )
      //正序
      for (j <- i until size - 1) array(j) = array(j + 1)
      //最后一项设置为 null,方便垃圾回收
      if (flag) {
        array(size - 1) = null
        size -= 1
      }
      value
    }
  }

  //方便打印
  override def toString: String = {
    val sj: StringJoiner = new StringJoiner(",")
    for (i <- 0 until size) {
      sj.add(array(i)._1 + "->" + array(i)._2)
    }
    "[" + sj.toString + "]"
  }
}

单链表实现

package com.shockang.study.algorithm.archive.lru

import java.util.StringJoiner

import scala.util.control.Breaks._

/**
 * 单链表实现的 LruCache
 *
 * @author Shockang
 */
class LinkedListLruCache[K, V](_maxSize: Int) {

  //单链表单个结点
  private class Node[T](_item: T, _next: Node[T]) {
    var item: T = _item
    var next: Node[T] = _next
  }

  //头结点
  private var head: Node[(K, V)] = _
  //哨兵
  private var sentry: Node[(K, V)] = new Node(null, head)

  //缓存 最大容量
  private val maxSize: Int = if (_maxSize >= 0) _maxSize else throw new IllegalArgumentException
  //缓存 当前容量
  var size: Int = 0

  //根据 key 查找 value,不存在则返回 null
  def get(k: K): V = {
    //key 不能为null
    if (k == null) throw new IllegalArgumentException("key == null")
    var cur: Node[(K, V)] = head
    while (cur != null) {
      //使用 equals 匹配
      if (k.equals(cur.item._1)) return cur.item._2
      cur = cur.next
    }
    null.asInstanceOf[V]
  }

  //插入一个键值对,返回旧值或者 null
  def put(k: K, v: V): V = {
    if (k == null || v == null) throw new IllegalArgumentException("key == null || value == null")
    //线程安全,加锁
    this.synchronized {
      var oldValue: V = null.asInstanceOf[V]
      //头结点特殊处理下,防止先删除再新增
      if (head != null && k.equals(head.item._1)) {
        oldValue = head.item._2
        head.item = (k, v)
      } else {
        oldValue = getAndRemove(k)
        //kv 对添加到头部
        addHead(k, v)
        //缓存数量加一
        size += 1
        //容量超了,就删除末尾
        if (size > maxSize) {
          removeLast()
        }
      }
      oldValue
    }
  }

  //根据 key 删除缓存中的k v 对,返回key 对应的 value,找不到返回 null
  def remove(k: K): V = {
    //key 不能为null
    if (k == null) throw new IllegalArgumentException("key == null")
    //线程安全,加锁
    this.synchronized {
      getAndRemove(k)
    }
  }

  //在缓存中查找 key,命中缓存后删除对应 kv 对
  private def getAndRemove(k: K): V = {
    //遍历的当前结点
    var cur: Node[(K, V)] = head
    //遍历的前结点
    var pre: Node[(K, V)] = sentry
    var oldValue: V = null.asInstanceOf[V]
    breakable(
      while (cur != null) {
        //equals 表示匹配
        if (k.equals(cur.item._1)) {
          //旧值
          oldValue = cur.item._2
          //删除当前结点
          pre.next = cur.next
          //如果删除的是头结点,需要重新设置下 head
          if (head == cur) {
            head = cur.next
            sentry = new Node(null, head)
          }
          //缓存数量减一
          size -= 1
          break
        }
        pre = cur
        cur = cur.next
      }
    )
    oldValue
  }

  //插入一个键值对到头部
  private def addHead(k: K, v: V): Unit = {
    //借用哨兵来搞定指针指向的问题
    sentry = new Node((k, v), head)
    head = sentry
    //新建哨兵还指向头结点
    sentry = new Node(null, head)
  }

  //删除最后一项
  private def removeLast(): Unit = {
    //从哨兵开始
    var cur: Node[(K, V)] = sentry
    //总共走了 maxSize-1 步
    for (_ <- 0 until maxSize) {
      cur = cur.next
    }
    //这时 cur 代表倒数第二项
    cur.next = null
  }

  override def toString: String = {
    val sj: StringJoiner = new StringJoiner(",")
    var cur: Node[(K, V)] = head
    while (cur != null) {
      sj.add(cur.item._1 + "->" + cur.item._2)
      cur = cur.next
    }
    "[" + sj + "]"
  }
}

单链表+散列表实现

package com.shockang.study.algorithm.archive.lru

import java.util.StringJoiner

import scala.collection.mutable

/**
 * 单链表+散列表实现的 LruCache,类似于 LinkedHashMap
 *
 * @author Shockang
 */
class ListAndMapLruCache[K, V](_maxSize: Int) {

  //单链表单个结点
  private class Node[T](_item: T, _next: Node[T]) {
    var item: T = _item
    var next: Node[T] = _next
  }

  //头结点
  private var head: Node[K] = _
  //哨兵
  private var sentry: Node[K] = new Node(null.asInstanceOf[K], head)

  //缓存 最大容量
  private val maxSize: Int = if (_maxSize > 0) _maxSize else throw new IllegalArgumentException
  //缓存 当前容量
  var size: Int = 0

  //散列表存储 kv 对
  private val cacheMap: mutable.HashMap[K, V] = mutable.HashMap()

  //key 对应的结点
  private val keyMap: mutable.HashMap[K, Node[K]] = mutable.HashMap()

  //key 对应结点的前一个结点,方便进行删除操作
  private val beforeKeyMap: mutable.HashMap[K, Node[K]] = mutable.HashMap()

  //根据 key 查找 value,不存在则返回 null
  def get(k: K): V = {
    //key 不能为null
    if (k == null) throw new IllegalArgumentException("key == null")
    cacheMap(k)
  }

  //插入一个键值对,返回旧值或者 null
  def put(k: K, v: V): V = {
    if (k == null || v == null) throw new IllegalArgumentException("key == null || value == null")
    //线程安全,加锁
    this.synchronized {
      //头结点特殊处理下
      var oldValue: V = null.asInstanceOf[V]
      if (head != null && k.equals(head.item)) {
        oldValue = cacheMap(k)
      } else {
        oldValue = getAndRemove(k)
        //key添加到头部
        addHead(k)
        //容量超了,就删除末尾
        if (size > maxSize) {
          removeLast()
        }
      }
      //更新所有 map
      cacheMap += k -> v
      keyMap += k -> head
      beforeKeyMap += k -> sentry
      oldValue
    }
  }

  //根据 key 删除缓存中的k v 对,返回key 对应的 value,找不到返回 null
  def remove(k: K): V = {
    //key 不能为null
    if (k == null) throw new IllegalArgumentException("key == null")
    //线程安全,加锁
    this.synchronized {
      getAndRemove(k, isRemove = true)
    }
  }

  //在缓存中查找 key,isRemove代表命中缓存后删除对应 kv 对
  private def getAndRemove(k: K, isRemove: Boolean = false): V = {
    var oldValue: V = null.asInstanceOf[V]
    if (cacheMap.contains(k)) {
      oldValue = cacheMap(k)
      val cur: Node[K] = keyMap(k)
      val pre: Node[K] = beforeKeyMap(k)
      //删除 cur 结点
      pre.next = cur.next
      if (cur.next != null) {
        beforeKeyMap += cur.next.item -> pre
      }
      //头结点需要特殊处理下
      if (head == cur) {
        head = cur.next
        sentry.next = head
      }
      size -= 1
    }
    //不需要删除缓存的时候就没必要删除
    if (isRemove) {
      cacheMap -= k
      keyMap -= k
      beforeKeyMap -= k
    }
    oldValue
  }

  private def removeLast(): Unit = {
    //从哨兵开始
    var cur: Node[K] = sentry
    //总共走了 maxSize-1 步
    for (_ <- 0 until maxSize) {
      cur = cur.next
    }
    //这时 cur 代表倒数第二项
    val last: K = cur.next.item
    //末尾的 Map 缓存也清理一下
    cacheMap -= last
    keyMap -= last
    beforeKeyMap -= last
    cur.next = null
    size -= 1
  }

  //插入key 到头部
  private def addHead(k: K): Unit = {
    //借用哨兵来操作
    sentry = new Node(k, head)
    //需要更新一下 k 对应结点的前驱指针
    if (head != null) {
      beforeKeyMap += head.item -> sentry
    }
    head = sentry
    //新建哨兵还指向头结点
    sentry = new Node(null.asInstanceOf[K], head)
    //缓存加一
    size += 1
  }

  override def toString: String = {
    val sj: StringJoiner = new StringJoiner(",")
    var cur: Node[K] = head
    while (cur != null) {
      sj.add(cur.item + "->" + cacheMap(cur.item))
      cur = cur.next
    }
    "[" + sj + "]"
  }
}

测试

package com.shockang.study.algorithm.archive.lru

object Main extends App {
  //val lru: ArrayLruCache[Int, String] = new ArrayLruCache(5)
  //val lru: LinkedListLruCache[Int, String] = new LinkedListLruCache(5)
  val lru: ListAndMapLruCache[Int, String] = new ListAndMapLruCache(5)
  lru.put(0, "0")
  lru.put(1, "1")
  lru.put(2, "2")
  println(lru)
  println(lru.get(1))
  lru.put(1, "11")
  println(lru.get(1))
  println(lru)
  lru.put(3, "3")
  lru.put(4, "4")
  lru.put(5, "5")
  println(lru)
  println(lru.get(1))
  lru.put(1, "12")
  println(lru.get(1))
  println(lru)
  println(lru.get(3))
  lru.put(3, "31")
  println(lru.get(3))
  println(lru)
  lru.remove(1)
  println(lru)
  lru.remove(3)
  println(lru)
}

输出

[2->2,1->1,0->0]
1
11
[1->11,2->2,0->0]
[5->5,4->4,3->3,1->11,2->2]
11
12
[1->12,5->5,4->4,3->3,2->2]
3
31
[3->31,1->12,5->5,4->4,2->2]
[3->31,5->5,4->4,2->2]
[5->5,4->4,2->2]

个人博客:

CSDN