diff --git a/hamilton/plugins/polars_lazyframe_extensions.py b/hamilton/plugins/polars_lazyframe_extensions.py index 6e53e478d..4fd6040bc 100644 --- a/hamilton/plugins/polars_lazyframe_extensions.py +++ b/hamilton/plugins/polars_lazyframe_extensions.py @@ -55,7 +55,7 @@ from hamilton import registry from hamilton.io import utils -from hamilton.io.data_adapters import DataLoader +from hamilton.io.data_adapters import DataLoader, DataSaver DATAFRAME_TYPE = pl.LazyFrame COLUMN_TYPE = pl.Expr @@ -297,6 +297,168 @@ def name(cls) -> str: return "feather" +@dataclasses.dataclass +class PolarsSinkParquetWriter(DataSaver): + """ + Class specifically to handle writing parquet files with Polars LazyFrame. + Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_parquet.html + """ + + path: Union[str, Path] + # kwargs: + compression: str = "zstd" + compression_level: Optional[int] = None + statistics: bool = False + row_group_size: Optional[int] = None + data_page_size: Optional[int] = None + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def _get_writing_kwargs(self): + kwargs = {} + if self.compression is not None: + kwargs["compression"] = self.compression + if self.compression_level is not None: + kwargs["compression_level"] = self.compression_level + if self.statistics is not None: + kwargs["statistics"] = self.statistics + if self.row_group_size is not None: + kwargs["row_group_size"] = self.row_group_size + if self.data_page_size is not None: + kwargs["data_page_size"] = self.data_page_size + return kwargs + + def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + data.sink_parquet(self.path, **self._get_writing_kwargs()) + metadata = utils.get_file_metadata(self.path) + return metadata + + @classmethod + def name(cls) -> str: + return "parquet" + + +@dataclasses.dataclass +class PolarsSinkCSVWriter(DataSaver): + """ + Class specifically to handle writing CSV files with Polars LazyFrame. + Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_csv.html + """ + + path: Union[str, Path] + # kwargs: + include_bom: bool = False + include_header: bool = True + separator: str = "," + line_terminator: str = "\n" + quote_char: str = '"' + batch_size: int = 1024 + datetime_format: Optional[str] = None + date_format: Optional[str] = None + time_format: Optional[str] = None + float_precision: Optional[int] = None + null_value: Optional[str] = None + quote_style: Optional[str] = None + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def _get_writing_kwargs(self): + kwargs = {} + if self.include_bom is not None: + kwargs["include_bom"] = self.include_bom + if self.include_header is not None: + kwargs["include_header"] = self.include_header + if self.separator is not None: + kwargs["separator"] = self.separator + if self.line_terminator is not None: + kwargs["line_terminator"] = self.line_terminator + if self.quote_char is not None: + kwargs["quote_char"] = self.quote_char + if self.batch_size is not None: + kwargs["batch_size"] = self.batch_size + if self.datetime_format is not None: + kwargs["datetime_format"] = self.datetime_format + if self.date_format is not None: + kwargs["date_format"] = self.date_format + if self.time_format is not None: + kwargs["time_format"] = self.time_format + if self.float_precision is not None: + kwargs["float_precision"] = self.float_precision + if self.null_value is not None: + kwargs["null_value"] = self.null_value + if self.quote_style is not None: + kwargs["quote_style"] = self.quote_style + return kwargs + + def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + data.sink_csv(self.path, **self._get_writing_kwargs()) + metadata = utils.get_file_metadata(self.path) + return metadata + + @classmethod + def name(cls) -> str: + return "csv" + + +@dataclasses.dataclass +class PolarsSinkIPCWriter(DataSaver): + """ + Class specifically to handle writing IPC/Feather files with Polars LazyFrame. + Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ipc.html + """ + + path: Union[str, Path] + # kwargs: + compression: Optional[str] = "zstd" + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def _get_writing_kwargs(self): + kwargs = {} + if self.compression is not None: + kwargs["compression"] = self.compression + return kwargs + + def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + data.sink_ipc(self.path, **self._get_writing_kwargs()) + metadata = utils.get_file_metadata(self.path) + return metadata + + @classmethod + def name(cls) -> str: + return "ipc" + + +@dataclasses.dataclass +class PolarsSinkNDJSONWriter(DataSaver): + """ + Class specifically to handle writing NDJSON files with Polars LazyFrame. + Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ndjson.html + Note: Load support for NDJSON is not yet implemented. + """ + + path: Union[str, Path] + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + data.sink_ndjson(self.path) + metadata = utils.get_file_metadata(self.path) + return metadata + + @classmethod + def name(cls) -> str: + return "ndjson" + + def register_data_loaders(): """Function to register the data loaders for this extension.""" for loader in [ @@ -308,3 +470,17 @@ def register_data_loaders(): register_data_loaders() + + +def register_data_savers(): + """Function to register the data savers for this extension.""" + for saver in [ + PolarsSinkParquetWriter, + PolarsSinkCSVWriter, + PolarsSinkIPCWriter, + PolarsSinkNDJSONWriter, + ]: + registry.register_adapter(saver) + + +register_data_savers() diff --git a/tests/plugins/test_polars_lazyframe_extensions.py b/tests/plugins/test_polars_lazyframe_extensions.py index 71715e6d8..2d3e01b22 100644 --- a/tests/plugins/test_polars_lazyframe_extensions.py +++ b/tests/plugins/test_polars_lazyframe_extensions.py @@ -27,6 +27,10 @@ PolarsScanCSVReader, PolarsScanFeatherReader, PolarsScanParquetReader, + PolarsSinkCSVWriter, + PolarsSinkIPCWriter, + PolarsSinkNDJSONWriter, + PolarsSinkParquetWriter, ) from hamilton.plugins.polars_post_1_0_0_extensions import ( PolarsAvroReader, @@ -193,3 +197,75 @@ def test_polars_spreadsheet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: assert write_kwargs["include_header"] is True assert "raise_if_empty" in read_kwargs assert read_kwargs["raise_if_empty"] is True + + +def test_polars_sink_parquet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test_sink.parquet" + + writer = PolarsSinkParquetWriter(path=file) + kwargs = writer._get_writing_kwargs() + metadata = writer.save_data(df) + + # Read back the data to verify it was written correctly + reader = PolarsScanParquetReader(file=file) + df2, _ = reader.load_data(pl.LazyFrame) + + assert PolarsSinkParquetWriter.applicable_types() == [pl.LazyFrame] + assert kwargs["compression"] == "zstd" + assert kwargs["statistics"] is False + assert file.exists() + assert metadata["file_metadata"]["path"] == str(file) + assert_frame_equal(df.collect(), df2.collect()) + + +def test_polars_sink_csv(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test_sink.csv" + + writer = PolarsSinkCSVWriter(path=file) + kwargs = writer._get_writing_kwargs() + metadata = writer.save_data(df) + + # Read back the data to verify it was written correctly + reader = PolarsScanCSVReader(file=file) + df2, _ = reader.load_data(pl.LazyFrame) + + assert PolarsSinkCSVWriter.applicable_types() == [pl.LazyFrame] + assert kwargs["separator"] == "," + assert kwargs["include_header"] is True + assert kwargs["batch_size"] == 1024 + assert file.exists() + assert metadata["file_metadata"]["path"] == str(file) + assert_frame_equal(df.collect(), df2.collect()) + + +def test_polars_sink_ipc(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test_sink.ipc" + + writer = PolarsSinkIPCWriter(path=file) + kwargs = writer._get_writing_kwargs() + metadata = writer.save_data(df) + + # Read back the data to verify it was written correctly + reader = PolarsScanFeatherReader(source=file) + df2, _ = reader.load_data(pl.LazyFrame) + + assert PolarsSinkIPCWriter.applicable_types() == [pl.LazyFrame] + assert kwargs["compression"] == "zstd" + assert file.exists() + assert metadata["file_metadata"]["path"] == str(file) + assert_frame_equal(df.collect(), df2.collect()) + + +def test_polars_sink_ndjson(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test_sink.ndjson" + + writer = PolarsSinkNDJSONWriter(path=file) + metadata = writer.save_data(df) + + # Read back the data to verify it was written correctly + df2 = pl.read_ndjson(file) + + assert PolarsSinkNDJSONWriter.applicable_types() == [pl.LazyFrame] + assert file.exists() + assert metadata["file_metadata"]["path"] == str(file) + assert_frame_equal(df.collect(), df2)