A better categorical encoder

Following up on the learnings from the Kaggle Red Hat business challenge attempt, and inspired by the work of others, I implemented a scikit-learn pipeline transformer that gracefully handles categorical variables with tons of values. It works by capping the number of columns a given variable can expand into (say, 20), then one-hot encoding as many of the most frequently occurring values as possible while leaving enough columns to binary encode the rest.

I considered that perhaps this should be broken down into two transformers: one for one-hot, one for binary encoding, but I think there is enough commonality that placing them together makes sense. For instance, in both cases you need a frequency count of the unique values showing up in a column. And the concern about how many one-hot columns to create is related to the number of columns to leave available for binary encoding.
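
To make the shared first step concrete, here is a minimal plain-Python sketch of the frequency count, using `collections.Counter` as a stand-in for the pandas `value_counts` call the encoder itself uses. The color list is the same toy data as in the example further down.

```python
from collections import Counter

colors = ['green', 'red', 'green', 'yellow', 'orange', 'red', 'purple']

# Count how often each unique value occurs; both encodings start from this.
val_counts = Counter(colors)
num_unique = len(val_counts)

# The most frequent values are the candidates for one-hot columns.
top_two = [val for val, _ in val_counts.most_common(2)]
```

The number of unique values drives how many columns are needed overall, and the counts decide which values get their own one-hot column.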

Since I haven't published the project as open source yet, here is an example use case:

>>> OmniEncoder(max_cols=4).fit_transform(
...             pd.DataFrame(
...                 columns=['color'],
...                 data=[
...                     ['green'],
...                     ['red'],
...                     ['green'],
...                     ['yellow'],
...                     ['orange'],
...                     ['red'],
...                     ['purple']
...                 ]
...             )
...         )
   color_green  color_red  color_10  color_01
0            1          0         0         0
1            0          1         0         0
2            1          0         0         0
3            0          0         1         1
4            0          0         0         1
5            0          1         0         0
6            0          0         1         0

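Given we need to encode 5 colors in a maximum of 4 columns, we only have room for 2 one-hot encoded columns, which go to the two most frequently occurring colors (green and red). The remaining three colors are then encoded within two binary columns; two bits hold exactly three values because the all-zero pattern is kept for rows that are not binary encoded.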
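
The arithmetic behind that choice can be sketched on its own. This is a simplified standalone version of the capacity calculation, not the encoder's exact code; the names `capacity`, `table` and `best` are just for illustration.

```python
def capacity(one_hot, max_cols):
    """Distinct values representable with `one_hot` one-hot columns, using the
    remaining columns as bits (the all-zero bit pattern is reserved)."""
    bits = max_cols - one_hot
    return one_hot + (2 ** bits - 1 if bits > 0 else 0)

max_cols, num_unique = 4, 5

# Capacity for every possible number of one-hot columns, largest first.
table = {oh: capacity(oh, max_cols)
         for oh in range(min(num_unique, max_cols), -1, -1)}
# table == {4: 4, 3: 4, 2: 5, 1: 8, 0: 15}

# Use as many one-hot columns as possible while still fitting every value.
best = max((oh for oh, cap in table.items() if cap >= num_unique), default=0)
```

With 3 or 4 one-hot columns only 4 values fit, so 2 is the largest number of one-hot columns that still leaves room for all 5 colors.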

And here is the source for the encoder:

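import heapq

import pandas as pd


# BaseTransformer is not defined in this post; it is assumed to be a project
# base class that supplies fit_transform on top of fit and transform.
class OmniEncoder(BaseTransformer):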
    """
    Encodes a categorical variable using no more than k columns. As many values as possible
    are one-hot encoded, the remaining are fit within a binary encoded set of columns.
    If necessary some are dropped (e.g. if (#unique_values) > 2^k - 1).

    In deciding which values to one-hot encode, those that appear more frequently are
    preferred.
    """
    def __init__(self, max_cols=20):
        self.column_infos = {}
        self.max_cols = max_cols
        if max_cols < 3 or max_cols > 100:
            raise ValueError("max_cols {} not within [3, 100]".format(max_cols))

    def fit(self, X, y=None, **fit_params):
        self.column_infos = {col: self._column_info(X[col], self.max_cols) for col in X.columns}
        return self

    def transform(self, X, **transform_params):
        return pd.concat(
            [self._encode_column(X[col], self.max_cols, *self.column_infos[col]) for col in X.columns],
            axis=1
        )

    @staticmethod
    def _encode_column(col, max_cols, one_hot_vals, binary_encoded_vals):
        num_one_hot = len(one_hot_vals)
        num_bits = max_cols - num_one_hot if len(binary_encoded_vals) > 0 else 0

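        # Indices start at 1 so the all-zero pattern stays free for values that are
        # not binary encoded. Values whose index does not fit in num_bits are dropped.
        binary_val_to_index = {
            val: idx + 1 for idx, val in enumerate(binary_encoded_vals) if idx + 1 < 2 ** num_bits
        }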

        bit_cols = [format(2 ** i, 'b').zfill(num_bits) for i in reversed(range(num_bits))]

        col_names = ["{}_{}".format(col.name, val) for val in one_hot_vals] + ["{}_{}".format(col.name, bit_col) for bit_col in bit_cols]

        def splat(v):
            v_one_hot = [1 if v == ohv else 0 for ohv in one_hot_vals]

            v_bits = [0]*num_bits
            if v in binary_val_to_index:
                v_bits = [int(b) for b in format(binary_val_to_index[v], 'b').zfill(num_bits)]

            return pd.Series(v_one_hot + v_bits)

        df = col.apply(splat)
        df.columns = col_names

        return df

    @staticmethod
    def _column_info(col, max_cols):
        """

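        Counts the values in `col` and splits them into one-hot and binary encoded lists.

        :param col: pd.Series
        :param max_cols: the maximum number of columns to expand into
        :return: (one_hot_vals, binary_encoded_vals), two sorted lists of values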
        """
        val_counts = dict(col.value_counts())
        num_one_hot = OmniEncoder._num_onehot(len(val_counts), max_cols)
        return OmniEncoder._partition_one_hot(val_counts, num_one_hot)

    @staticmethod
    def _partition_one_hot(val_counts, num_one_hot):
        """
        Partitions the values in `val_counts` into a list of values that should be
        one-hot encoded and a list of values that should be binary encoded.

        The `num_one_hot` most popular values are chosen to be one-hot encoded.

        :param val_counts: {'val': 433}
        :param num_one_hot: the number of elements to be one-hot encoded
        :return: ['val1', 'val2'], ['val55', 'val59']
        """
        one_hot_vals = [k for (k, count) in heapq.nlargest(num_one_hot, val_counts.items(), key=lambda t: t[1])]
        one_hot_vals_lookup = set(one_hot_vals)

        bin_encoded_vals = [val for val in val_counts if val not in one_hot_vals_lookup]

        return sorted(one_hot_vals), sorted(bin_encoded_vals)


    @staticmethod
    def _num_onehot(n, k):
        """
        Determines the number of onehot columns we can have to encode n values
        in no more than k columns, assuming we will binary encode the rest.

        :param n: The number of unique values to encode
        :param k: The maximum number of columns we have
        :return: The number of one-hot columns to use
        """
        num_one_hot = min(n, k)

        def num_bin_vals(num):
            if num == 0:
                return 0
            return 2 ** num - 1

        def capacity(oh):
            """
            Capacity given we are using `oh` one hot columns.
            """
            return oh + num_bin_vals(k - oh)

        while capacity(num_one_hot) < n and num_one_hot > 0:
            num_one_hot -= 1

        return num_one_hot
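
To see how the binary columns get filled, here is a small standalone sketch of the index-to-bits step, using the three leftover colors from the example above. The helper name `to_bits` is just for illustration and is not part of the encoder.

```python
def to_bits(index, num_bits):
    """Return `index` as a list of 0/1 ints, most significant bit first."""
    return [int(b) for b in format(index, 'b').zfill(num_bits)]

binary_vals = ['orange', 'purple', 'yellow']  # sorted, as in the example
num_bits = 2

# Indices start at 1 so that the all-zero pattern stays free for
# one-hot encoded values and for values never seen during fit.
codes = {val: to_bits(idx + 1, num_bits) for idx, val in enumerate(binary_vals)}
not_binary_encoded = [0] * num_bits
```

This gives orange `[0, 1]`, purple `[1, 0]` and yellow `[1, 1]`, matching the `color_10` and `color_01` columns in the example output.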