curia.modeling.metrics module

class curia.modeling.metrics.MetricsInterface

Bases: object

metric_factories = {'auc': <function _get_auc>, 'decile_gain': <function _get_decile_gain>, 'r2': <function _get_r2>, 'shap': <function _get_shap>, 'smd': <function _get_smd>}

A class to simplify the model job output extraction process.

Much of this code was originally in: https://dbc-5c784d6d-d7a7.cloud.databricks.com/?o=1189584077923919#notebook/3744383587579554/command/3744383587579629

classmethod get_basic_metrics(session: Session, model_job_id: str) → Dict[str, float]

Extract the most basic metrics from the model job as a dictionary.

Parameters:
  • session (Session) – The session object to use.

  • model_job_id (str) – The ID of the model job from which to get the metrics.

Returns:

A dictionary mapping each basic metric name to its value.

Return type:

Dict[str, float or int]

classmethod get_registered_metrics() → List[str]

Get the names of the metrics that can be used with the get_metrics method.

Returns:

The list of metric names.

Return type:

List[str]

classmethod get_metrics(session: Session, metric_names: List[str], model_job_id: str) → Dict[str, float]

Get the metrics named in ‘metric_names’ from the model job.

Parameters:
  • session (Session) – The session object to use.

  • metric_names (List[str]) – The metrics to pull.

  • model_job_id (str) – The ID of the model job from which to get output.

Raises:

ValueError – If any requested metric is not a registered metric.

Returns:

The metrics requested.

Return type:

Dict[str, float or int]
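
The lookup behaviour documented for get_registered_metrics and get_metrics (validate the requested names against the registered factories, raise ValueError on an unknown name, then compute each metric) can be sketched as below. This is a minimal illustration and not the library's implementation: SketchMetricsInterface, FAKE_OUTPUTS, and the stand-in factories are hypothetical, and the session argument is omitted because no real Session is available here.

```python
from typing import Callable, Dict, List, Union

Number = Union[float, int]

# Hypothetical stand-in for model job output; the real library would
# read this through a Session rather than from a module-level dict.
FAKE_OUTPUTS: Dict[str, Dict[str, Number]] = {
    "job-123": {"auc": 0.81, "r2": 0.42},
}


class SketchMetricsInterface:
    """Illustrative only: mirrors the documented method names, not the real code."""

    # Maps a metric name to a factory that computes it for a model job id.
    metric_factories: Dict[str, Callable[[str], Number]] = {
        "auc": lambda job_id: FAKE_OUTPUTS[job_id]["auc"],
        "r2": lambda job_id: FAKE_OUTPUTS[job_id]["r2"],
    }

    @classmethod
    def get_registered_metrics(cls) -> List[str]:
        # Names are returned in registration (insertion) order.
        return list(cls.metric_factories)

    @classmethod
    def get_metrics(cls, metric_names: List[str], model_job_id: str) -> Dict[str, Number]:
        # Reject the whole request if any name is unknown.
        unknown = [name for name in metric_names if name not in cls.metric_factories]
        if unknown:
            raise ValueError(f"Unregistered metrics: {unknown}")
        return {name: cls.metric_factories[name](model_job_id) for name in metric_names}


metrics = SketchMetricsInterface.get_metrics(["auc", "r2"], "job-123")
```

Checking get_registered_metrics before calling get_metrics is one way to avoid the ValueError when metric names come from user input.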

curia.modeling.metrics.register_metric(metric_name: str)

Register a metric with the instance.

Parameters:

metric_name (str) – The name of the metric.