Source code for starboost.line_searchers
import abc
import numpy as np
__all__ = ['LeafLineSearcher']
class LineSearcher(abc.ABC):
@abc.abstractmethod
def fit(self, y_true, y_pred, gradient, direction):
pass
@abc.abstractmethod
def update(self, direction):
pass
[docs]class LeafLineSearcher(LineSearcher):
def __init__(self, update_leaf):
self.update_leaf = update_leaf
def fit(self, y_true, y_pred, gradient, direction):
leaves = np.unique(direction)
self.updates_ = {}
for leaf in leaves:
mask = direction == leaf
self.updates_[leaf] = self.update_leaf(y_true[mask], y_pred[mask], gradient[mask])
return self
def update(self, direction):
for leaf, update in self.updates_.items():
np.place(direction, direction == leaf, update)
return direction