Source code for bamt.nodes.discrete_node

import random
from itertools import product
from typing import Type, Dict, Union, List

import numpy as np
from pandas import DataFrame, crosstab

from .base import BaseNode
from .schema import DiscreteParams


class DiscreteNode(BaseNode):
    """
    Main class of Discrete Node
    """

    def __init__(self, name):
        super().__init__(name)
        self.type = "Discrete"

    def fit_parameters(self, data: DataFrame, num_workers: int = 1) -> DiscreteParams:
        """
        Train params for Discrete Node

        data: DataFrame to train on
        num_workers: number of Parallel Workers. Currently unused: the fit
            runs synchronously in the calling thread.

        Method returns a dict with two keys:
            "vals": values of this node observed in data, converted to str
                and ordered by the sorted value index.
            "cprob": without parents, the list of marginal probabilities
                (same order as "vals"); with parents, a dict
                {<combination>: probabilities}, where <combination> is
                str() of the list of stringified parent values (parents
                ordered as disc_parents + cont_parents).
        """
        parents = self.disc_parents + self.cont_parents
        dist = data[self.name].value_counts(normalize=True).sort_index()
        vals = [str(i) for i in dist.index.to_list()]

        if not parents:
            return {"cprob": dist.to_list(), "vals": vals}

        # Every combination of observed parent values starts with a uniform
        # distribution; combinations that actually occur in data are
        # overwritten below with their empirical conditional distribution.
        uniform_prob = 1 / len(vals)
        cprob = {
            str([str(i) for i in comb]): [uniform_prob] * len(vals)
            for comb in product(*[data[p].unique() for p in parents])
        }

        conditional_dist = crosstab(
            data[self.name].to_list(),
            [data[p] for p in parents],
            normalize="columns",
        ).T
        tight_form = conditional_dist.to_dict("tight")

        for comb, probs in zip(tight_form["index"], tight_form["data"]):
            # With a single parent the index entry is a bare value rather
            # than a tuple. Wrap it so the key is built exactly like the
            # default keys above (and like the str(pvals) lookup in
            # get_dist), including for values that contain quotes.
            parent_values = comb if len(parents) > 1 else (comb,)
            cprob[str([str(i) for i in parent_values])] = probs

        return {"cprob": cprob, "vals": vals}

    @staticmethod
    def get_dist(node_info, pvals):
        """
        Return the distribution of the node for the given parent values

        params:
        node_info: nodes info from distributions
        pvals: parent values; if empty or None, node_info["cprob"] itself is
            returned, otherwise the entry stored under str(pvals)
        """
        if not pvals:
            return node_info["cprob"]
        # noinspection PyTypeChecker
        return node_info["cprob"][str(pvals)]

    def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
        """
        Return value from discrete node

        params:
        node_info: nodes info from distributions
        pvals: parent values
        """
        vals = node_info["vals"]
        dist = np.array(self.get_dist(node_info, pvals))
        cumulative_dist = np.cumsum(dist)

        rand = np.random.random()
        # Same tie rule as the vectorized path: first cumulative value
        # strictly greater than the draw.
        rindex = int(self.searchsorted_per_row(cumulative_dist, rand))
        # Rounding can leave the last cumulative value slightly below 1, in
        # which case the search may point one past the end.
        rindex = min(rindex, len(vals) - 1)
        return vals[rindex]

    @staticmethod
    def searchsorted_per_row(row, rand_num):
        """
        Return the insertion index (side="right") of rand_num in the sorted
        sequence row, i.e. the number of elements of row that are <= rand_num
        """
        return np.searchsorted(row, rand_num, side="right")

    def vectorized_choose(
        self,
        node_info: Dict[str, Union[float, str]],
        pvals_array: np.ndarray,
        n_samples: int,
    ) -> np.ndarray:
        """
        Vectorized method to return values from a discrete node.

        params:
        node_info: node's info from distributions
        pvals_array: array of parent values, each row corresponds to a set of
            parent values; None or empty if the node has no parents
        n_samples: number of samples to generate

        Returns an array with one sampled value per distribution row:
        n_samples values without parents, otherwise one value per row of
        pvals_array[:n_samples].
        """
        vals = node_info["vals"]

        # Generate a matrix of distributions
        if pvals_array is None or len(pvals_array) == 0:
            # Handle the case with no parent nodes
            dist = np.array(node_info["cprob"])
            dist_matrix = np.tile(dist, (n_samples, 1))
        else:
            # Ensure pvals_array is limited to current batch size
            pvals_array = pvals_array[:n_samples]
            # Compute distribution for each set of parent values
            dist_matrix = np.array(
                [self.get_dist(node_info, pvals.tolist()) for pvals in pvals_array]
            )

        # Ensure that dist_matrix is 2D
        if dist_matrix.ndim == 1:
            dist_matrix = dist_matrix.reshape(1, -1)

        # Generate cumulative distributions
        cumulative_dist_matrix = np.cumsum(dist_matrix, axis=1)

        # One random draw per row. Searching every draw in every row would
        # yield rows * n_samples values instead of one sample per row.
        random_nums = np.random.rand(cumulative_dist_matrix.shape[0])

        # For a non-decreasing row, the count of entries <= draw equals
        # np.searchsorted(row, draw, side="right").
        indices = (cumulative_dist_matrix <= random_nums[:, np.newaxis]).sum(axis=1)
        # Guard against a last cumulative value slightly below 1.
        indices = np.minimum(indices, len(vals) - 1)

        return np.array(vals)[indices]

    @staticmethod
    def predict(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
        """function for prediction based on evidence values in discrete node

        Args:
            node_info (Dict[str, Union[float, str]]): parameters of node
            pvals (List[str]): values in parents nodes

        Returns:
            str: prediction, the value with the highest probability; ties
                are broken uniformly at random
        """
        vals = node_info["vals"]
        dist = DiscreteNode.get_dist(node_info, pvals)

        max_value = max(dist)
        indices = [index for index, value in enumerate(dist) if value == max_value]

        # Only touch the random generator when there is an actual tie.
        if len(indices) == 1:
            max_ind = indices[0]
        else:
            max_ind = random.choice(indices)

        return vals[max_ind]