curia.modeling.metrics module
- class curia.modeling.metrics.MetricsInterface
Bases: object
A class that simplifies extracting output from model jobs.
Much of this code originally lived in the Databricks notebook at: https://dbc-5c784d6d-d7a7.cloud.databricks.com/?o=1189584077923919#notebook/3744383587579554/command/3744383587579629
- metric_factories = {'auc': <function _get_auc>, 'decile_gain': <function _get_decile_gain>, 'r2': <function _get_r2>, 'shap': <function _get_shap>, 'smd': <function _get_smd>}
- classmethod get_basic_metrics(session: Session, model_job_id: str) → Dict[str, float]
Extract the most basic metrics from the model job as a dictionary.
- Parameters:
session (Session) – The session object to use.
model_job_id (str) – The ID of the model job from which to get the metrics.
- Returns:
A dictionary mapping metric names to their values.
- Return type:
Dict[str, float or int]
- classmethod get_registered_metrics() → List[str]
Get the names of metrics that can be used with the get_metrics function.
- Returns:
The list of metric names.
- Return type:
List[str]
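As a minimal sketch, assuming get_registered_metrics simply reports the keys of the class-level metric_factories dict documented above (the source does not confirm this), the lookup could look like the following. MetricsRegistrySketch and its placeholder factories are hypothetical stand-ins, not part of curia.

```python
from typing import Callable, Dict, List


class MetricsRegistrySketch:
    # Hypothetical stand-in for MetricsInterface.metric_factories: a class-level
    # dict mapping metric names to factory functions. The lambdas are placeholders.
    metric_factories: Dict[str, Callable[..., float]] = {
        "auc": lambda session, model_job_id: 0.0,
        "r2": lambda session, model_job_id: 0.0,
    }

    @classmethod
    def get_registered_metrics(cls) -> List[str]:
        # Assumption: the registered names are exactly the registry's keys.
        return list(cls.metric_factories.keys())


print(MetricsRegistrySketch.get_registered_metrics())  # ['auc', 'r2']
```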
- classmethod get_metrics(session: Session, metric_names: List[str], model_job_id: str) → Dict[str, float]
Get the metrics named in ‘metric_names’ from the model job.
- Parameters:
session (Session) – The session object to use.
metric_names (List[str]) – The metrics to pull.
model_job_id (str) – The ID of the model job from which to get output.
- Raises:
ValueError – If any name in metric_names is not a registered metric.
- Returns:
The requested metrics, keyed by metric name.
- Return type:
Dict[str, float or int]
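The validate-then-dispatch behavior documented for get_metrics can be sketched as below. This is an illustration under stated assumptions, not curia's implementation: get_metrics_sketch, METRIC_FACTORIES, and the _fake_* factories are hypothetical, and the (session, model_job_id) factory signature is assumed. All names are checked up front, so an unregistered metric raises ValueError before any metric is computed.

```python
from typing import Any, Callable, Dict, List


# Hypothetical factories standing in for _get_auc, _get_r2, etc.
# Assumption: each takes (session, model_job_id) and returns a single number.
def _fake_auc(session: Any, model_job_id: str) -> float:
    return 0.81


def _fake_r2(session: Any, model_job_id: str) -> float:
    return 0.42


METRIC_FACTORIES: Dict[str, Callable[[Any, str], float]] = {
    "auc": _fake_auc,
    "r2": _fake_r2,
}


def get_metrics_sketch(
    session: Any, metric_names: List[str], model_job_id: str
) -> Dict[str, float]:
    # Validate every requested name before computing anything.
    unknown = [name for name in metric_names if name not in METRIC_FACTORIES]
    if unknown:
        raise ValueError(
            f"Unregistered metric(s): {unknown}; registered: {sorted(METRIC_FACTORIES)}"
        )
    # Dispatch each requested name to its factory, preserving request order.
    return {name: METRIC_FACTORIES[name](session, model_job_id) for name in metric_names}


print(get_metrics_sketch(None, ["auc", "r2"], "job-123"))  # {'auc': 0.81, 'r2': 0.42}
```

Passing an unregistered name such as "f1" raises ValueError that lists the offending names alongside the registered ones.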
- curia.modeling.metrics.register_metric(metric_name: str)
Register a metric with MetricsInterface.
- Parameters:
metric_name (str) – The name of the metric.