package main
import "fmt"
const (
RED bool = true
BLACK bool = false
)
type RBTree struct {root *Node}
func NewRBTree() *RBTree {return &RBTree{}
}
func (t *RBTree) Search (v int64) *Node {
n := t.root
for n != nil {
if v < n.v {n = n.l} else if v >n.v {n = n.r} else {return n}
}
return nil
}
func (t *RBTree) Insert(v int64) {
if t.root == nil {t.root = &Node{v:v, c:BLACK}
return
}
n := t.root
addNode := &Node{v:v, c:RED}
var np *Node
for n != nil {if (v == n.v) {return}
np = n
if v < n.v {n = n.l} else {n = n.r}
}
addNode.p = np
if v < np.v {np.l = addNode} else {np.r = addNode}
// 验证规定
t.insertFix(addNode)
}
func (t *RBTree) insertFix(n *Node) {for !isBlack(n.p) {uncle := n.p.getBrother()
if !isBlack(uncle) {
n.p.c = BLACK
uncle.c = BLACK
uncle.p.c = RED
n = n.p.p
t.root.c = BLACK
continue
}
if n.p == n.p.p.l {
if n == n.p.l {
n.p.p.c = RED
n.p.c = BLACK
n = n.p.p
n.rTurn(t)
t.root.c = BLACK
continue
}
n = n.p
n.lTurn(t)
t.root.c = BLACK
continue
}
if n == n.p.r {
n.p.p.c = RED
n.p.c = BLACK
n = n.p.p
n.lTurn(t)
t.root.c = BLACK
continue
}
n = n.p
n.rTurn(t)
t.root.c = BLACK
}
}
func (t *RBTree) Del(v int64) {n := t.Search(v)
np := n.p
if n == nil {return}
var revise string
if n.l == nil && n.r == nil {revise = "none"} else if np == nil {revise = "root"} else if n == np.l {revise = "left"} else if n == np.r {revise = "right"}
// 内含递归
n.del(t)
if isBlack(n) {
if revise == "root" {t.delFix(t.root)
} else if revise == "left" {t.delFix(np.l)
} else if revise == "right" {t.delFix(np.r)
}
}
}
func (t *RBTree) delFix(n *Node) {
var b *Node
for n != t.root && isBlack(n) {
// 删除节点为左节点,存在兄弟节点
if n.p.l == n && n.p.r != nil {
b = n.p.r
if !isBlack(b) {
b.c = BLACK
n.p.c = RED
n.p.lTurn(t)
} else if isBlack(b) && b.l != nil && isBlack(b.l) && b.r != nil && isBlack(b.r){
b.c = RED
n = n.p
} else if isBlack(b) && b.l != nil && !isBlack(b.l) && b.r != nil && isBlack(b.r) {
b.c = RED
b.l.c = BLACK
b.rTurn(t)
} else if isBlack(b) && b.r != nil && !isBlack(b.r) {
b.c = RED
b.r.c = BLACK
b.p.c = BLACK
b.p.lTurn(t)
n = t.root
}
} else if n.p.r == n && n.p.l != nil {
b = n.p.l
if !isBlack(b) {
b.c = BLACK
n.p.c = RED
n.p.rTurn(t)
} else if isBlack(b) && b.l != nil && isBlack(b.l) && b.r != nil && isBlack(b.r) {
b.c = RED
n = n.p
} else if isBlack(b) && b.l != nil && isBlack(b.l) && b.r != nil && !isBlack(b.r) {
b.c = RED
b.r.c = BLACK
b.lTurn(t)
} else if isBlack(b) && b.l != nil && !isBlack(b.l) {
b.c = RED
b.l.c = BLACK
b.p.c = BLACK
b.p.rTurn(t)
n = t.root
}
} else {return}
}
}
func (t *RBTree) min(n *Node) *Node {
if n == nil {return nil}
for n.l != nil {n = n.l}
return n
}
func (t *RBTree) max(n *Node) *Node {
if n == nil {return nil}
for n.r != nil {n =n.r}
return n
}
// 获取前驱节点
func (t *RBTree) getPredecessor(n *Node) *Node {
if n == nil {return nil}
// 小分支里的最大节点
if n.l != nil {return t.max(n.l)
}
// 如果不存在又节点就从父节点中找
for {
if n.p == nil {break}
if n == n.p.r {return n.p}
n = n.p
}
return nil
}
// 获取继承节点
func (t *RBTree) getSuccessor(n *Node) *Node {
if n == nil {return nil}
// 大分支里的最小节点
if n.r != nil {return t.min(n.r)
}
// 如果不存在又节点就从父节点中找
for {
if n.p == nil {break}
if n == n.p.l {return n.p}
n = n.p
}
return nil
}
/////////////////////////////////////
func isBlack(n *Node) bool {
if n == nil {return true}
return n.c == BLACK
}
func setColor(n *Node, color bool) {
if n == nil {return}
n.c = color
}
////////////////////////////////////////////////////////////
type Node struct {
l, r, p *Node
v int64
c bool
}
func (n *Node) lTurn(t *RBTree) {
if n == nil || n.r == nil {return}
np := n.p
nr := n.r
// 父节点解决
if np != nil {
if n == np.l {np.l = nr} else {np.r = nr}
} else {t.root = nr}
// 解决本人关系
n.p = nr
n.r = nr.l
if n.r != nil {
// 左孙节点解决
n.r.p = n
}
// 右子节点解决
nr.l = n
nr.p = np
}
func (n *Node) rTurn(t *RBTree) {
if n == nil || n.l == nil {return}
nl := n.l
np := n.p
// 父节点解决
if np != nil {
if n == np.l {np.l = nl} else {np.r = nl}
} else {t.root = nl}
// 解决本人关系
n.p = nl
n.l = nl.r
if n.l != nil {
// 右孙节点解决
n.l.p = n
}
// 左子节点解决
nl.r = n
nl.p = np
}
func (n *Node) getBrother() *Node {
if n.p == nil {return nil}
if n.p.l == n {return n.p.r}
if n.p.r == n {return n.p.l}
return nil
}
func (n *Node) del(t *RBTree) {
np := n.p
if n.l == nil && n.r == nil {
// 节点为尾部节点(不存在子节点)// 根节点、左尾节点、右尾节点
if n == t.root {t.root = nil} else if np.l == n {np.l = nil} else {np.r = nil}
} else if n.l != nil && n.r == nil {
// 存在左子节点
if n == t.root {
// 根节点
n.l.p = nil
t.root = n.l
} else {
n.l.p = np
if np.l == n {np.l = n.l} else {np.r = n.l}
}
} else if n.l == nil && n.r != nil {
// 存在右子节点
if n == t.root {
n.r.p = nil
t.root = n.r
} else {
n.r.p = np
if np.l == n {np.l = n.r} else {np.r = n.r}
}
} else {
// 存在两个节点
successor := t.getSuccessor(n)
n.v = successor.v
n.c = successor.c
// 递归删除
successor.del(t)
}
}
func main() {btree := NewRBTree()
btree.Insert(1)
btree.Insert(21)
btree.Insert(3)
btree.Insert(4)
btree.Insert(5)
btree.Insert(6)
btree.Insert(7)
btree.Insert(8)
btree.Insert(9)
btree.Insert(10)
btree.Insert(1)
fmt.Println(btree.Search(11))
fmt.Println(btree.Search(10))
fmt.Println(btree.Search(9))
fmt.Println(btree.Search(8))
fmt.Println(btree.Search(7))
fmt.Println(btree.Search(6))
fmt.Println(btree.Search(5))
fmt.Println(btree.Search(4))
fmt.Println(btree.Search(3))
fmt.Println(btree.Search(21))
fmt.Println(btree.Search(1))
fmt.Println(btree.root)
}