diff --git a/drawdata/__init__.py b/drawdata/__init__.py index 700ae14..fd17529 100644 --- a/drawdata/__init__.py +++ b/drawdata/__init__.py @@ -47,16 +47,16 @@ def data_as_polars(self): def data_as_X_y(self): import numpy as np - colors = [_['color'] for _ in self.data] + labels = [_['label'] for _ in self.data] # Updated to use 'label' instead of 'color' # Assume that we're dealing with regression in this case - if np.unique(colors).shape[0] == 1: + if np.unique(labels).shape[0] == 1: X = np.array([[_['x']] for _ in self.data]) y = np.array([_['y'] for _ in self.data]) return X, y X = np.array([[_['x'], _['y']] for _ in self.data]) - return X, colors + return X, labels class BarWidget(anywidget.AnyWidget):