Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 43 additions & 37 deletions native-engine/datafusion-ext-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use datafusion::{common::Result, logical_expr::ScalarFunctionImplementation};
use datafusion_ext_commons::df_unimplemented_err;
Expand All @@ -39,51 +39,57 @@ pub fn create_auron_ext_function(
name: &str,
spark_partition_id: usize,
) -> Result<ScalarFunctionImplementation> {
macro_rules! cache {
($func:path) => {{
static CELL: OnceLock<ScalarFunctionImplementation> = OnceLock::new();
CELL.get_or_init(|| Arc::new($func)).clone()
}};
}
Comment on lines +42 to +47
Copy link

Copilot AI Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cache macro has a critical bug: all invocations share the same static CELL variable. This means that once any function is cached in the first call, all subsequent function lookups will return that same cached function, regardless of which function was requested.

For example, if "Spark_NullIf" is called first, it will cache spark_null_if::spark_null_if. Then when "Spark_NullIfZero" is called, it will return the same spark_null_if::spark_null_if function instead of spark_null_if::spark_null_if_zero.

To fix this, each invocation needs its own unique static variable. This can be achieved by making the static variable name unique per function path, or by using a different caching approach such as a global HashMap with function names as keys.

Copilot uses AI. Check for mistakes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Rust, statics declared inside functions or blocks are not global singletons sharing one name; each is a local singleton unique to the specific scope where it appears. Because every `macro_rules!` invocation expands into its own block, each match arm here gets its own distinct `CELL` static.
Just running this test in AuronQuerySuite

test("my cache test") {
    withTable("my_cache_table") {
      sql("""
            |create table my_cache_table using parquet as
            |select col1, col2 from values ('a,A', '{"a":"1", "b":"2"}'), ('b,B', '{"a":"3", "b":"4"}'), ('c,C', '{"a":"5", "b":"6"}')
            |""".stripMargin)
      sql("""
            |select split(col1, ',')[0],
            |       split(col1, ',')[1],
            |       get_json_object(col2, '$.a'),
            |       get_json_object(col2, '$.b')
            |from my_cache_table
            |""".stripMargin).show()
    }
  }

we can see the following correct answer.

+---------------------+---------------------+--------------------------+--------------------------+
|split(col1, ,, -1)[0]|split(col1, ,, -1)[1]|get_json_object(col2, $.a)|get_json_object(col2, $.b)|
+---------------------+---------------------+--------------------------+--------------------------+
|                    a|                    A|                         1|                         2|
|                    b|                    B|                         3|                         4|
|                    c|                    C|                         5|                         6|
+---------------------+---------------------+--------------------------+--------------------------+

It correctly handles the different ext functions StringSplit, GetParsedJsonObject, and ParseJson.

ProjectExec [
(spark_ext_function_Spark_StringSplit(#2@0, ,)).[1] AS #16, 
(spark_ext_function_Spark_StringSplit(#2@0, ,)).[2] AS #17, spark_ext_function_Spark_GetParsedJsonObject(spark_ext_function_Spark_ParseJson(#3@1), $.a) AS #18, 
spark_ext_function_Spark_GetParsedJsonObject(spark_ext_function_Spark_ParseJson(#3@1), $.b) AS #19
], schema=[#16:Utf8;N, #17:Utf8;N, #18:Utf8;N, #19:Utf8;N]

// Auron ext function names: those used for Spark should start with 'Spark_',
// those used for Flink should start with 'Flink_',
// and likewise for other engines.
Ok(match name {
"Placeholder" => Arc::new(|_| panic!("placeholder() should never be called")),
"Spark_NullIf" => Arc::new(spark_null_if::spark_null_if),
"Spark_NullIfZero" => Arc::new(spark_null_if::spark_null_if_zero),
"Spark_UnscaledValue" => Arc::new(spark_unscaled_value::spark_unscaled_value),
"Spark_MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
"Spark_CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
"Spark_Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
"Spark_XxHash64" => Arc::new(spark_hash::spark_xxhash64),
"Spark_Sha224" => Arc::new(spark_crypto::spark_sha224),
"Spark_Sha256" => Arc::new(spark_crypto::spark_sha256),
"Spark_Sha384" => Arc::new(spark_crypto::spark_sha384),
"Spark_Sha512" => Arc::new(spark_crypto::spark_sha512),
"Spark_MD5" => Arc::new(spark_crypto::spark_md5),
"Spark_GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
"Spark_NullIf" => cache!(spark_null_if::spark_null_if),
"Spark_NullIfZero" => cache!(spark_null_if::spark_null_if_zero),
"Spark_UnscaledValue" => cache!(spark_unscaled_value::spark_unscaled_value),
"Spark_MakeDecimal" => cache!(spark_make_decimal::spark_make_decimal),
"Spark_CheckOverflow" => cache!(spark_check_overflow::spark_check_overflow),
"Spark_Murmur3Hash" => cache!(spark_hash::spark_murmur3_hash),
"Spark_XxHash64" => cache!(spark_hash::spark_xxhash64),
"Spark_Sha224" => cache!(spark_crypto::spark_sha224),
"Spark_Sha256" => cache!(spark_crypto::spark_sha256),
"Spark_Sha384" => cache!(spark_crypto::spark_sha384),
"Spark_Sha512" => cache!(spark_crypto::spark_sha512),
"Spark_MD5" => cache!(spark_crypto::spark_md5),
"Spark_GetJsonObject" => cache!(spark_get_json_object::spark_get_json_object),
"Spark_GetParsedJsonObject" => {
Arc::new(spark_get_json_object::spark_get_parsed_json_object)
cache!(spark_get_json_object::spark_get_parsed_json_object)
}
"Spark_ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
"Spark_MakeArray" => Arc::new(spark_make_array::array),
"Spark_StringSpace" => Arc::new(spark_strings::string_space),
"Spark_StringRepeat" => Arc::new(spark_strings::string_repeat),
"Spark_StringSplit" => Arc::new(spark_strings::string_split),
"Spark_StringConcat" => Arc::new(spark_strings::string_concat),
"Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
"Spark_StringLower" => Arc::new(spark_strings::string_lower),
"Spark_StringUpper" => Arc::new(spark_strings::string_upper),
"Spark_InitCap" => Arc::new(spark_initcap::string_initcap),
"Spark_Year" => Arc::new(spark_dates::spark_year),
"Spark_Month" => Arc::new(spark_dates::spark_month),
"Spark_Day" => Arc::new(spark_dates::spark_day),
"Spark_Quarter" => Arc::new(spark_dates::spark_quarter),
"Spark_Hour" => Arc::new(spark_dates::spark_hour),
"Spark_Minute" => Arc::new(spark_dates::spark_minute),
"Spark_Second" => Arc::new(spark_dates::spark_second),
"Spark_BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union),
"Spark_Round" => Arc::new(spark_round::spark_round),
"Spark_BRound" => Arc::new(spark_bround::spark_bround),
"Spark_ParseJson" => cache!(spark_get_json_object::spark_parse_json),
"Spark_MakeArray" => cache!(spark_make_array::array),
"Spark_StringSpace" => cache!(spark_strings::string_space),
"Spark_StringRepeat" => cache!(spark_strings::string_repeat),
"Spark_StringSplit" => cache!(spark_strings::string_split),
"Spark_StringConcat" => cache!(spark_strings::string_concat),
"Spark_StringConcatWs" => cache!(spark_strings::string_concat_ws),
"Spark_StringLower" => cache!(spark_strings::string_lower),
"Spark_StringUpper" => cache!(spark_strings::string_upper),
"Spark_InitCap" => cache!(spark_initcap::string_initcap),
"Spark_Year" => cache!(spark_dates::spark_year),
"Spark_Month" => cache!(spark_dates::spark_month),
"Spark_Day" => cache!(spark_dates::spark_day),
"Spark_Quarter" => cache!(spark_dates::spark_quarter),
"Spark_Hour" => cache!(spark_dates::spark_hour),
"Spark_Minute" => cache!(spark_dates::spark_minute),
"Spark_Second" => cache!(spark_dates::spark_second),
"Spark_BrickhouseArrayUnion" => cache!(brickhouse::array_union::array_union),
"Spark_Round" => cache!(spark_round::spark_round),
"Spark_BRound" => cache!(spark_bround::spark_bround),
"Spark_NormalizeNanAndZero" => {
Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
cache!(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
}
"Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan),
"Spark_IsNaN" => cache!(spark_isnan::spark_isnan),
_ => df_unimplemented_err!("spark ext function not implemented: {name}")?,
})
}
Loading