112 Zeilen
2.3 KiB
Go
112 Zeilen
2.3 KiB
Go
package weightedrandom
|
|
|
|
import "sort"
|
|
|
|
var (
|
|
_ sort.Interface = &wflist[int]{}
|
|
)
|
|
|
|
type wflist[T any] []*weightitem[T]
|
|
|
|
func (wfl wflist[T]) Less(i, j int) bool {
|
|
return wfl[i].adjusted < wfl[j].adjusted
|
|
}
|
|
|
|
func (wfl wflist[T]) Len() int {
|
|
return len(wfl)
|
|
}
|
|
|
|
func (wfl wflist[T]) Swap(i, j int) {
|
|
wfl[i], wfl[j] = wfl[j], wfl[i]
|
|
}
|
|
|
|
// WeightRandom is the Type for an weighted random list
|
|
type WeightRandom[T any] struct {
|
|
data wflist[T]
|
|
weightFunc AdjustFunc
|
|
randomFunc RandomFunc
|
|
sorted bool
|
|
totalWeight int
|
|
adjustedWeight int
|
|
sorti sort.Interface
|
|
accessfunc AccessModifier
|
|
}
|
|
|
|
func (wr *WeightRandom[T]) reweight() {
|
|
wr.updatetotal()
|
|
for _, item := range wr.data {
|
|
item.adjusted = wr.weightFunc(item.weight, wr.totalWeight)
|
|
}
|
|
}
|
|
|
|
func (wr *WeightRandom[T]) sort() {
|
|
if wr.sorted {
|
|
return
|
|
}
|
|
if wr.sorti == nil {
|
|
wr.sorti = sort.Reverse(&wr.data)
|
|
}
|
|
wr.reweight()
|
|
sort.Stable(wr.sorti)
|
|
wr.sorted = true
|
|
}
|
|
|
|
func (wr *WeightRandom[T]) updatetotal() {
|
|
newtotal := 0
|
|
newadjusted := 0
|
|
for _, item := range wr.data {
|
|
newtotal += item.weight
|
|
|
|
newadjusted += item.adjusted
|
|
}
|
|
wr.adjustedWeight = newadjusted
|
|
wr.totalWeight = newtotal
|
|
}
|
|
|
|
// AddElement adds an single Element to the list
|
|
func (wr *WeightRandom[T]) AddElement(item T, weight int) {
|
|
wr.sorted = false
|
|
wr.totalWeight += weight
|
|
wr.data = append(wr.data, &weightitem[T]{weight: weight, value: item})
|
|
}
|
|
|
|
// AddElements adds an list of elements to the list
|
|
func (wr *WeightRandom[T]) AddElements(weight int, items ...T) {
|
|
for _, item := range items {
|
|
wr.AddElement(item, weight)
|
|
}
|
|
wr.reweight()
|
|
wr.sort()
|
|
}
|
|
|
|
// Get returns an random value of the list of items or the null value of the type in case no element was added
|
|
func (wr *WeightRandom[T]) Get() (t T) {
|
|
if !wr.sorted {
|
|
wr.reweight()
|
|
wr.sort()
|
|
}
|
|
pos := wr.randomFunc(wr.data.Len())
|
|
for _, item := range wr.data {
|
|
pos -= item.adjusted
|
|
if pos <= 0 {
|
|
if wr.accessfunc != nil {
|
|
wr.sorted = false
|
|
item.weight = wr.accessfunc(item.weight)
|
|
}
|
|
return item.value
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Cleanup deletes the items where the the function returns true
|
|
func (wr *WeightRandom[T]) Cleanup(f func(T) bool) {
|
|
nl := make(wflist[T], 0, wr.data.Len())
|
|
for _, item := range wr.data {
|
|
if !f(item.value) {
|
|
nl = append(nl, item)
|
|
}
|
|
}
|
|
wr.data = nl
|
|
}
|