From b5c0657f8cd99c7842052f55c2d12e7561333c30 Mon Sep 17 00:00:00 2001 From: userzhy <48518279+userzhy@users.noreply.github.com> Date: Fri, 19 Dec 2025 14:48:29 +0000 Subject: [PATCH 1/2] feat(go): implement union type for xlang serialization Add support for tagged union types in Go Fory, enabling cross-language serialization compatibility with other languages like C++, Java, Python, and Rust that already support union/variant types. Changes: - Add UNION (38) and NONE (39) type constants to types.go - Implement Union struct and unionSerializer in union.go - Add RegisterUnionType API to Fory for registering union alternatives - Add comprehensive tests in union_test.go The implementation follows the same binary protocol as other languages: 1. Write variant index (varuint32) 2. In xlang mode, write type info for the active alternative 3. Write the value data using the alternative's serializer Fixes #3031 --- go/fory/types.go | 4 + go/fory/union.go | 315 ++++++++++++++++++++++++++++++++++++++++++ go/fory/union_test.go | 227 ++++++++++++++++++++++++++++++ 3 files changed, 546 insertions(+) create mode 100644 go/fory/union.go create mode 100644 go/fory/union_test.go diff --git a/go/fory/types.go b/go/fory/types.go index e16a7545ab..8ccfcaa064 100644 --- a/go/fory/types.go +++ b/go/fory/types.go @@ -98,6 +98,10 @@ const ( FLOAT32_ARRAY = 36 // FLOAT64_ARRAY one dimensional float64 array FLOAT64_ARRAY = 37 + // UNION a tagged union type that can hold one of several alternative types + UNION = 38 + // NONE represents an empty/unit value with no data (e.g., for empty union alternatives) + NONE = 39 // UINT8 Unsigned 8-bit little-endian integer UINT8 = 64 diff --git a/go/fory/union.go b/go/fory/union.go new file mode 100644 index 0000000000..ea3fe694dd --- /dev/null +++ b/go/fory/union.go @@ -0,0 +1,315 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "fmt" + "reflect" +) + +// Union represents a tagged union type that can hold one of several alternative types. +// It's equivalent to Rust's enum, C++'s std::variant, or Python's typing.Union. +// +// The Value field holds the actual value, which must be one of the types specified +// when registering the Union type. +// +// Example usage: +// +// // Create a union that can hold int32 or string +// union := fory.Union{Value: int32(42)} +// // or +// union := fory.Union{Value: "hello"} +type Union struct { + Value interface{} +} + +// NewUnion creates a new Union with the given value. +func NewUnion(value interface{}) Union { + return Union{Value: value} +} + +// IsNil returns true if the union holds no value. +func (u Union) IsNil() bool { + return u.Value == nil +} + +// unionSerializer serializes Union types. +// +// Serialization format: +// 1. Write variant index (varuint32) - identifies which alternative type is active +// 2. In xlang mode, write type info for the active alternative +// 3. Write the value data using the alternative's serializer +type unionSerializer struct { + type_ reflect.Type + alternativeTypes []reflect.Type + typeResolver *TypeResolver + alternativeTypeInfo []*TypeInfo +} + +// newUnionSerializer creates a new serializer for Union types with the specified alternatives. +// The alternativeTypes slice defines the allowed types in order - the index is used as the variant index. +func newUnionSerializer(typeResolver *TypeResolver, alternativeTypes []reflect.Type) *unionSerializer { + typeInfos := make([]*TypeInfo, len(alternativeTypes)) + return &unionSerializer{ + type_: reflect.TypeOf(Union{}), + alternativeTypes: alternativeTypes, + typeResolver: typeResolver, + alternativeTypeInfo: typeInfos, + } +} + +// findAlternativeIndex finds the index of the type that matches the given value. +// Returns -1 if no match is found. +func (s *unionSerializer) findAlternativeIndex(value reflect.Value) int { + if !value.IsValid() || (value.Kind() == reflect.Interface && value.IsNil()) { + return -1 + } + + valueType := value.Type() + if valueType.Kind() == reflect.Interface { + valueType = value.Elem().Type() + } + + for i, altType := range s.alternativeTypes { + if valueType == altType { + return i + } + // Also check if the value is assignable to the alternative type + if valueType.AssignableTo(altType) { + return i + } + // For pointer types, check the elem type + if valueType.Kind() == reflect.Ptr && altType.Kind() == reflect.Ptr { + if valueType.Elem() == altType.Elem() { + return i + } + } + } + return -1 +} + +func (s *unionSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, value reflect.Value) error { + buf := ctx.Buffer() + + // Get the Union value + var union Union + if value.Kind() == reflect.Ptr { + if value.IsNil() { + buf.WriteInt8(NullFlag) + return nil + } + union = value.Elem().Interface().(Union) + } else { + union = value.Interface().(Union) + } + + // Handle null union value + if union.Value == nil { + switch refMode { + case RefModeTracking, RefModeNullOnly: + buf.WriteInt8(NullFlag) + } + return nil + } + + // Write ref flag for non-null + switch refMode { + case RefModeTracking: + refWritten, err := ctx.RefResolver().WriteRefOrNull(buf, value) + if err != nil { + return err + } + if refWritten { + return nil + } + case RefModeNullOnly: + buf.WriteInt8(NotNullValueFlag) + } + + // Write type info if needed + if writeType { + buf.WriteVaruint32Small7(uint32(UNION)) + } + + return s.WriteData(ctx, value) +} + +func (s *unionSerializer) WriteData(ctx *WriteContext, value reflect.Value) error { + buf := ctx.Buffer() + + // Get the Union value + var union Union + if value.Kind() == reflect.Ptr { + union = value.Elem().Interface().(Union) + } else { + union = value.Interface().(Union) + } + + // Find which alternative type matches the value + innerValue := reflect.ValueOf(union.Value) + activeIndex := s.findAlternativeIndex(innerValue) + + if activeIndex < 0 { + return fmt.Errorf("union value type %T doesn't match any alternative in %v", union.Value, s.alternativeTypes) + } + + // Write the active variant index + buf.WriteVaruint32(uint32(activeIndex)) + + // Get the serializer for the active alternative + altType := s.alternativeTypes[activeIndex] + serializer, err := ctx.TypeResolver().getSerializerByType(altType, false) + if err != nil { + return fmt.Errorf("no serializer for union alternative type %v: %w", altType, err) + } + + // In xlang mode, write type info for the alternative + if ctx.TypeResolver().isXlang { + typeInfo, err := ctx.TypeResolver().getTypeInfo(innerValue, true) + if err != nil { + return err + } + if err := ctx.TypeResolver().WriteTypeInfo(buf, typeInfo); err != nil { + return err + } + } + + // Write the value data + return serializer.WriteData(ctx, innerValue) +} + +func (s *unionSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, value reflect.Value) error { + buf := ctx.Buffer() + + switch refMode { + case RefModeTracking: + refID, err := ctx.RefResolver().TryPreserveRefId(buf) + if err != nil { + return err + } + if int8(refID) < NotNullValueFlag { + obj := ctx.RefResolver().GetReadObject(refID) + if obj.IsValid() { + if value.Kind() == reflect.Ptr { + value.Elem().Set(obj) + } else { + value.Set(obj) + } + } + return nil + } + case RefModeNullOnly: + flag := buf.ReadInt8() + if flag == NullFlag { + return nil + } + } + + if readType { + typeId := buf.ReadVaruint32Small7() + if TypeId(typeId) != UNION { + return fmt.Errorf("expected UNION type id %d, got %d", UNION, typeId) + } + } + + return s.ReadData(ctx, s.type_, value) +} + +func (s *unionSerializer) ReadData(ctx *ReadContext, type_ reflect.Type, value reflect.Value) error { + buf := ctx.Buffer() + + // Read the stored variant index + storedIndex := buf.ReadVaruint32() + + // Validate index is within bounds + if int(storedIndex) >= len(s.alternativeTypes) { + return fmt.Errorf("union index out of bounds: %d (max: %d)", storedIndex, len(s.alternativeTypes)-1) + } + + // Get the alternative type + altType := s.alternativeTypes[storedIndex] + + // Get serializer for this alternative + serializer, err := ctx.TypeResolver().getSerializerByType(altType, false) + if err != nil { + return fmt.Errorf("no serializer for union alternative type %v: %w", altType, err) + } + + // In xlang mode, read type info for the alternative + if ctx.TypeResolver().isXlang { + // Read the type info - we need to pass a value for the ReadTypeInfo function + dummyValue := reflect.New(altType).Elem() + _, err := ctx.TypeResolver().ReadTypeInfo(buf, dummyValue) + if err != nil { + return err + } + } + + // Create a value to hold the alternative data + altValue := reflect.New(altType).Elem() + + // Read the value data + if err := serializer.ReadData(ctx, altType, altValue); err != nil { + return err + } + + // Set the union value + union := Union{Value: altValue.Interface()} + if value.Kind() == reflect.Ptr { + value.Elem().Set(reflect.ValueOf(union)) + } else { + value.Set(reflect.ValueOf(union)) + } + + return nil +} + +func (s *unionSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) error { + return s.Read(ctx, refMode, false, value) +} + +// RegisterUnionType registers a Union type with the specified alternative types. +// The alternative types are the types that the union can hold. +// Returns an error if registration fails. +// +// Example: +// +// f := fory.NewFory() +// err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) +// if err != nil { +// panic(err) +// } +func (f *Fory) RegisterUnionType(alternativeTypes ...reflect.Type) error { + if len(alternativeTypes) == 0 { + return fmt.Errorf("union must have at least one alternative type") + } + + unionType := reflect.TypeOf(Union{}) + serializer := newUnionSerializer(f.typeResolver, alternativeTypes) + + // Register the union type with the serializer + f.typeResolver.typeToSerializers[unionType] = serializer + + // Also register pointer type + ptrUnionType := reflect.PtrTo(unionType) + f.typeResolver.typeToSerializers[ptrUnionType] = &ptrToValueSerializer{ + valueSerializer: serializer, + } + + return nil +} diff --git a/go/fory/union_test.go b/go/fory/union_test.go new file mode 100644 index 0000000000..a858f3626b --- /dev/null +++ b/go/fory/union_test.go @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnionBasicTypes(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + require.NoError(t, err) + + // Test with int32 value + unionInt := Union{Value: int32(42)} + data, err := f.Serialize(unionInt) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, int32(42), result.Value) + + // Test with string value + unionStr := Union{Value: "hello"} + data, err = f.Serialize(unionStr) + require.NoError(t, err) + + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, "hello", result.Value) +} + +func TestUnionMultipleTypes(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType( + reflect.TypeOf(int32(0)), + reflect.TypeOf(""), + reflect.TypeOf(float64(0)), + ) + require.NoError(t, err) + + // Test with int32 + union1 := Union{Value: int32(123)} + data, err := f.Serialize(union1) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, int32(123), result.Value) + + // Test with string + union2 := Union{Value: "test"} + data, err = f.Serialize(union2) + require.NoError(t, err) + + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, "test", result.Value) + + // Test with float64 + union3 := Union{Value: float64(3.14)} + data, err = f.Serialize(union3) + require.NoError(t, err) + + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.InDelta(t, 3.14, result.Value.(float64), 0.0001) +} + +func TestUnionNullValue(t *testing.T) { + f := NewFory(WithTrackRef(true)) + err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + require.NoError(t, err) + + // Test with nil value + unionNil := Union{Value: nil} + data, err := f.Serialize(unionNil) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Nil(t, result.Value) +} + +func TestUnionWithPointerValue(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType(reflect.TypeOf((*int32)(nil)), reflect.TypeOf("")) + require.NoError(t, err) + + // Test with pointer to int32 + val := int32(42) + unionPtr := Union{Value: &val} + data, err := f.Serialize(unionPtr) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + + resultPtr, ok := result.Value.(*int32) + require.True(t, ok) + require.Equal(t, int32(42), *resultPtr) +} + +func TestUnionNewHelper(t *testing.T) { + union := NewUnion(int32(42)) + require.Equal(t, int32(42), union.Value) + require.False(t, union.IsNil()) + + unionNil := NewUnion(nil) + require.True(t, unionNil.IsNil()) +} + +func TestUnionInvalidAlternative(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + require.NoError(t, err) + + // Try to serialize a union with an unregistered alternative type + unionBool := Union{Value: true} + _, err = f.Serialize(unionBool) + require.Error(t, err) + require.Contains(t, err.Error(), "doesn't match any alternative") +} + +func TestUnionEmptyRegistration(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType() + require.Error(t, err) + require.Contains(t, err.Error(), "at least one alternative type") +} + +func TestUnionWithBytes(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType(reflect.TypeOf([]byte{}), reflect.TypeOf("")) + require.NoError(t, err) + + // Test with bytes + unionBytes := Union{Value: []byte("hello")} + data, err := f.Serialize(unionBytes) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, []byte("hello"), result.Value) + + // Test with string + unionStr := Union{Value: "world"} + data, err = f.Serialize(unionStr) + require.NoError(t, err) + + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, "world", result.Value) +} + +func TestUnionWithRefTracking(t *testing.T) { + f := NewFory(WithTrackRef(true)) + err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + require.NoError(t, err) + + // Test with int32 value + unionInt := Union{Value: int32(42)} + data, err := f.Serialize(unionInt) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, int32(42), result.Value) + + // Test with string value + unionStr := Union{Value: "hello"} + data, err = f.Serialize(unionStr) + require.NoError(t, err) + + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, "hello", result.Value) +} + +func TestUnionWithInt64AndBool(t *testing.T) { + f := NewFory() + err := f.RegisterUnionType(reflect.TypeOf(int64(0)), reflect.TypeOf(false)) + require.NoError(t, err) + + // Test with int64 + union1 := Union{Value: int64(9999999999)} + data, err := f.Serialize(union1) + require.NoError(t, err) + + var result Union + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, int64(9999999999), result.Value) + + // Test with bool + union2 := Union{Value: true} + data, err = f.Serialize(union2) + require.NoError(t, err) + + err = f.Deserialize(data, &result) + require.NoError(t, err) + require.Equal(t, true, result.Value) +} From 950e1dae09e04ad43dc6189cd45d797c3150dc3f Mon Sep 17 00:00:00 2001 From: userzhy <48518279+userzhy@users.noreply.github.com> Date: Sun, 21 Dec 2025 04:09:53 +0000 Subject: [PATCH 2/2] refactor(go): use Go generics for union types Address review feedback from chaokunyang. Since Go 1.23+ is used, replace interface{}-based Union with generic Union2, Union3, Union4 types. New API: - Union2[T1, T2], Union3[T1, T2, T3], Union4[T1, T2, T3, T4] generic types - NewUnion2A/B, NewUnion3A/B/C, NewUnion4A/B/C/D constructors - Match() methods for pattern matching (type-safe callbacks) - Index(), IsFirst(), IsSecond() accessor methods - First(), Second() value extraction methods (panics on wrong access) - RegisterUnion2Type, RegisterUnion3Type, RegisterUnion4Type functions Benefits: - Compile-time type safety - Explicit variant tracking - Pattern matching support - No runtime type assertions needed --- go/fory/union.go | 458 ++++++++++++++++++++++++++++++------------ go/fory/union_test.go | 330 +++++++++++++++++++----------- 2 files changed, 551 insertions(+), 237 deletions(-) diff --git a/go/fory/union.go b/go/fory/union.go index ea3fe694dd..1d2e94d03d 100644 --- a/go/fory/union.go +++ b/go/fory/union.go @@ -22,109 +22,228 @@ import ( "reflect" ) -// Union represents a tagged union type that can hold one of several alternative types. -// It's equivalent to Rust's enum, C++'s std::variant, or Python's typing.Union. +// ============================================================================ +// Generic Union Types +// ============================================================================ + +// Union2 represents a tagged union type that can hold one of two alternative types. +// It's equivalent to Rust's enum with two variants, C++'s std::variant, +// or Python's typing.Union[T1, T2]. // -// The Value field holds the actual value, which must be one of the types specified -// when registering the Union type. +// Note: The fields are exported for serialization purposes. Use the provided +// methods (Index, First, Second, Match) for normal usage. // // Example usage: // -// // Create a union that can hold int32 or string -// union := fory.Union{Value: int32(42)} -// // or -// union := fory.Union{Value: "hello"} -type Union struct { - Value interface{} +// // Create a union that holds an int32 +// union := fory.NewUnion2A[int32, string](42) +// // or create a union that holds a string +// union := fory.NewUnion2B[int32, string]("hello") +// +// // Pattern matching +// union.Match( +// func(i int32) { fmt.Println("got int:", i) }, +// func(s string) { fmt.Println("got string:", s) }, +// ) +type Union2[T1 any, T2 any] struct { + V1 *T1 + V2 *T2 + Idx int +} + +// NewUnion2A creates a Union2 containing the first alternative type. +func NewUnion2A[T1 any, T2 any](t T1) Union2[T1, T2] { + return Union2[T1, T2]{V1: &t, Idx: 1} +} + +// NewUnion2B creates a Union2 containing the second alternative type. +func NewUnion2B[T1 any, T2 any](t T2) Union2[T1, T2] { + return Union2[T1, T2]{V2: &t, Idx: 2} +} + +// Match performs pattern matching on the union, calling the appropriate function +// based on which alternative is active. +func (u Union2[T1, T2]) Match(case1 func(T1), case2 func(T2)) { + switch u.Idx { + case 1: + case1(*u.V1) + case 2: + case2(*u.V2) + default: + panic("Union2 is uninitialized") + } +} + +// Index returns the 1-based index of the active alternative. +func (u Union2[T1, T2]) Index() int { + return u.Idx +} + +// IsFirst returns true if the first alternative is active. +func (u Union2[T1, T2]) IsFirst() bool { + return u.Idx == 1 +} + +// IsSecond returns true if the second alternative is active. +func (u Union2[T1, T2]) IsSecond() bool { + return u.Idx == 2 +} + +// First returns the first alternative value. Panics if not the active alternative. +func (u Union2[T1, T2]) First() T1 { + if u.Idx != 1 { + panic("Union2: First() called but second alternative is active") + } + return *u.V1 +} + +// Second returns the second alternative value. Panics if not the active alternative. +func (u Union2[T1, T2]) Second() T2 { + if u.Idx != 2 { + panic("Union2: Second() called but first alternative is active") + } + return *u.V2 +} + +// Union3 represents a tagged union type that can hold one of three alternative types. +type Union3[T1 any, T2 any, T3 any] struct { + V1 *T1 + V2 *T2 + V3 *T3 + Idx int } -// NewUnion creates a new Union with the given value. -func NewUnion(value interface{}) Union { - return Union{Value: value} +// NewUnion3A creates a Union3 containing the first alternative type. +func NewUnion3A[T1 any, T2 any, T3 any](t T1) Union3[T1, T2, T3] { + return Union3[T1, T2, T3]{V1: &t, Idx: 1} } -// IsNil returns true if the union holds no value. -func (u Union) IsNil() bool { - return u.Value == nil +// NewUnion3B creates a Union3 containing the second alternative type. +func NewUnion3B[T1 any, T2 any, T3 any](t T2) Union3[T1, T2, T3] { + return Union3[T1, T2, T3]{V2: &t, Idx: 2} } -// unionSerializer serializes Union types. +// NewUnion3C creates a Union3 containing the third alternative type. +func NewUnion3C[T1 any, T2 any, T3 any](t T3) Union3[T1, T2, T3] { + return Union3[T1, T2, T3]{V3: &t, Idx: 3} +} + +// Match performs pattern matching on the union. +func (u Union3[T1, T2, T3]) Match(f1 func(T1), f2 func(T2), f3 func(T3)) { + switch u.Idx { + case 1: + f1(*u.V1) + case 2: + f2(*u.V2) + case 3: + f3(*u.V3) + default: + panic("Union3 is uninitialized") + } +} + +// Index returns the 1-based index of the active alternative. +func (u Union3[T1, T2, T3]) Index() int { + return u.Idx +} + +// Union4 represents a tagged union type that can hold one of four alternative types. +type Union4[T1 any, T2 any, T3 any, T4 any] struct { + V1 *T1 + V2 *T2 + V3 *T3 + V4 *T4 + Idx int +} + +// NewUnion4A creates a Union4 containing the first alternative type. +func NewUnion4A[T1 any, T2 any, T3 any, T4 any](t T1) Union4[T1, T2, T3, T4] { + return Union4[T1, T2, T3, T4]{V1: &t, Idx: 1} +} + +// NewUnion4B creates a Union4 containing the second alternative type. +func NewUnion4B[T1 any, T2 any, T3 any, T4 any](t T2) Union4[T1, T2, T3, T4] { + return Union4[T1, T2, T3, T4]{V2: &t, Idx: 2} +} + +// NewUnion4C creates a Union4 containing the third alternative type. +func NewUnion4C[T1 any, T2 any, T3 any, T4 any](t T3) Union4[T1, T2, T3, T4] { + return Union4[T1, T2, T3, T4]{V3: &t, Idx: 3} +} + +// NewUnion4D creates a Union4 containing the fourth alternative type. +func NewUnion4D[T1 any, T2 any, T3 any, T4 any](t T4) Union4[T1, T2, T3, T4] { + return Union4[T1, T2, T3, T4]{V4: &t, Idx: 4} +} + +// Match performs pattern matching on the union. +func (u Union4[T1, T2, T3, T4]) Match(f1 func(T1), f2 func(T2), f3 func(T3), f4 func(T4)) { + switch u.Idx { + case 1: + f1(*u.V1) + case 2: + f2(*u.V2) + case 3: + f3(*u.V3) + case 4: + f4(*u.V4) + default: + panic("Union4 is uninitialized") + } +} + +// Index returns the 1-based index of the active alternative. +func (u Union4[T1, T2, T3, T4]) Index() int { + return u.Idx +} + +// ============================================================================ +// Union Serializer +// ============================================================================ + +// unionSerializer serializes generic Union types. // // Serialization format: // 1. Write variant index (varuint32) - identifies which alternative type is active // 2. In xlang mode, write type info for the active alternative // 3. Write the value data using the alternative's serializer type unionSerializer struct { - type_ reflect.Type - alternativeTypes []reflect.Type - typeResolver *TypeResolver - alternativeTypeInfo []*TypeInfo + type_ reflect.Type + alternativeTypes []reflect.Type + typeResolver *TypeResolver } // newUnionSerializer creates a new serializer for Union types with the specified alternatives. -// The alternativeTypes slice defines the allowed types in order - the index is used as the variant index. -func newUnionSerializer(typeResolver *TypeResolver, alternativeTypes []reflect.Type) *unionSerializer { - typeInfos := make([]*TypeInfo, len(alternativeTypes)) +func newUnionSerializer(type_ reflect.Type, typeResolver *TypeResolver, alternativeTypes []reflect.Type) *unionSerializer { return &unionSerializer{ - type_: reflect.TypeOf(Union{}), - alternativeTypes: alternativeTypes, - typeResolver: typeResolver, - alternativeTypeInfo: typeInfos, + type_: type_, + alternativeTypes: alternativeTypes, + typeResolver: typeResolver, } } -// findAlternativeIndex finds the index of the type that matches the given value. -// Returns -1 if no match is found. -func (s *unionSerializer) findAlternativeIndex(value reflect.Value) int { - if !value.IsValid() || (value.Kind() == reflect.Interface && value.IsNil()) { - return -1 - } - - valueType := value.Type() - if valueType.Kind() == reflect.Interface { - valueType = value.Elem().Type() - } +func (s *unionSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, value reflect.Value) { + buf := ctx.Buffer() - for i, altType := range s.alternativeTypes { - if valueType == altType { - return i - } - // Also check if the value is assignable to the alternative type - if valueType.AssignableTo(altType) { - return i - } - // For pointer types, check the elem type - if valueType.Kind() == reflect.Ptr && altType.Kind() == reflect.Ptr { - if valueType.Elem() == altType.Elem() { - return i - } - } + // Handle nil pointer + if value.Kind() == reflect.Ptr && value.IsNil() { + buf.WriteInt8(NullFlag) + return } - return -1 -} -func (s *unionSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, value reflect.Value) error { - buf := ctx.Buffer() - - // Get the Union value - var union Union + // Get the actual struct value if value.Kind() == reflect.Ptr { - if value.IsNil() { - buf.WriteInt8(NullFlag) - return nil - } - union = value.Elem().Interface().(Union) - } else { - union = value.Interface().(Union) + value = value.Elem() } - // Handle null union value - if union.Value == nil { + // Get the index field to check if initialized + indexField := value.FieldByName("Idx") + if !indexField.IsValid() || indexField.Int() == 0 { switch refMode { case RefModeTracking, RefModeNullOnly: buf.WriteInt8(NullFlag) } - return nil + return } // Write ref flag for non-null @@ -132,10 +251,11 @@ func (s *unionSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bo case RefModeTracking: refWritten, err := ctx.RefResolver().WriteRefOrNull(buf, value) if err != nil { - return err + ctx.SetError(FromError(err)) + return } if refWritten { - return nil + return } case RefModeNullOnly: buf.WriteInt8(NotNullValueFlag) @@ -146,61 +266,75 @@ func (s *unionSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bo buf.WriteVaruint32Small7(uint32(UNION)) } - return s.WriteData(ctx, value) + s.WriteData(ctx, value) } -func (s *unionSerializer) WriteData(ctx *WriteContext, value reflect.Value) error { +func (s *unionSerializer) WriteData(ctx *WriteContext, value reflect.Value) { buf := ctx.Buffer() - // Get the Union value - var union Union + // Get the actual struct value if value.Kind() == reflect.Ptr { - union = value.Elem().Interface().(Union) - } else { - union = value.Interface().(Union) + value = value.Elem() } - // Find which alternative type matches the value - innerValue := reflect.ValueOf(union.Value) - activeIndex := s.findAlternativeIndex(innerValue) + // Get the active index (1-based in the struct) + indexField := value.FieldByName("Idx") + activeIndex := int(indexField.Int()) - 1 // Convert to 0-based - if activeIndex < 0 { - return fmt.Errorf("union value type %T doesn't match any alternative in %v", union.Value, s.alternativeTypes) + if activeIndex < 0 || activeIndex >= len(s.alternativeTypes) { + ctx.SetError(SerializationErrorf("union index out of bounds: %d", activeIndex+1)) + return } - // Write the active variant index + // Write the active variant index (0-based for protocol) buf.WriteVaruint32(uint32(activeIndex)) + // Get the value pointer field (V1, V2, V3, V4) + fieldName := fmt.Sprintf("V%d", activeIndex+1) + valueField := value.FieldByName(fieldName) + if !valueField.IsValid() || valueField.IsNil() { + ctx.SetError(SerializationErrorf("union value field %s is nil", fieldName)) + return + } + + // Get the actual value (dereference the pointer) + innerValue := valueField.Elem() + // Get the serializer for the active alternative altType := s.alternativeTypes[activeIndex] serializer, err := ctx.TypeResolver().getSerializerByType(altType, false) if err != nil { - return fmt.Errorf("no serializer for union alternative type %v: %w", altType, err) + ctx.SetError(FromError(fmt.Errorf("no serializer for union alternative type %v: %w", altType, err))) + return } // In xlang mode, write type info for the alternative if ctx.TypeResolver().isXlang { typeInfo, err := ctx.TypeResolver().getTypeInfo(innerValue, true) if err != nil { - return err + ctx.SetError(FromError(err)) + return } if err := ctx.TypeResolver().WriteTypeInfo(buf, typeInfo); err != nil { - return err + ctx.SetError(FromError(err)) + return } } // Write the value data - return serializer.WriteData(ctx, innerValue) + serializer.WriteData(ctx, innerValue) } -func (s *unionSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, value reflect.Value) error { +func (s *unionSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, value reflect.Value) { buf := ctx.Buffer() + ctxErr := ctx.Err() switch refMode { case RefModeTracking: refID, err := ctx.RefResolver().TryPreserveRefId(buf) if err != nil { - return err + ctx.SetError(FromError(err)) + return } if int8(refID) < NotNullValueFlag { obj := ctx.RefResolver().GetReadObject(refID) @@ -211,34 +345,40 @@ func (s *unionSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, value.Set(obj) } } - return nil + return } case RefModeNullOnly: - flag := buf.ReadInt8() + flag := buf.ReadInt8(ctxErr) if flag == NullFlag { - return nil + return } } if readType { - typeId := buf.ReadVaruint32Small7() + typeId := buf.ReadVaruint32Small7(ctxErr) if TypeId(typeId) != UNION { - return fmt.Errorf("expected UNION type id %d, got %d", UNION, typeId) + ctx.SetError(DeserializationErrorf("expected UNION type id %d, got %d", UNION, typeId)) + return } } - return s.ReadData(ctx, s.type_, value) + s.ReadData(ctx, s.type_, value) } -func (s *unionSerializer) ReadData(ctx *ReadContext, type_ reflect.Type, value reflect.Value) error { +func (s *unionSerializer) ReadData(ctx *ReadContext, type_ reflect.Type, value reflect.Value) { buf := ctx.Buffer() + ctxErr := ctx.Err() - // Read the stored variant index - storedIndex := buf.ReadVaruint32() + // Read the stored variant index (0-based) + storedIndex := int(buf.ReadVaruint32(ctxErr)) + if ctx.HasError() { + return + } // Validate index is within bounds - if int(storedIndex) >= len(s.alternativeTypes) { - return fmt.Errorf("union index out of bounds: %d (max: %d)", storedIndex, len(s.alternativeTypes)-1) + if storedIndex < 0 || storedIndex >= len(s.alternativeTypes) { + ctx.SetError(DeserializationErrorf("union index out of bounds: %d (max: %d)", storedIndex, len(s.alternativeTypes)-1)) + return } // Get the alternative type @@ -247,16 +387,17 @@ func (s *unionSerializer) ReadData(ctx *ReadContext, type_ reflect.Type, value r // Get serializer for this alternative serializer, err := ctx.TypeResolver().getSerializerByType(altType, false) if err != nil { - return fmt.Errorf("no serializer for union alternative type %v: %w", altType, err) + ctx.SetError(FromError(fmt.Errorf("no serializer for union alternative type %v: %w", altType, err))) + return } // In xlang mode, read type info for the alternative if ctx.TypeResolver().isXlang { - // Read the type info - we need to pass a value for the ReadTypeInfo function dummyValue := reflect.New(altType).Elem() _, err := ctx.TypeResolver().ReadTypeInfo(buf, dummyValue) if err != nil { - return err + ctx.SetError(FromError(err)) + return } } @@ -264,43 +405,62 @@ func (s *unionSerializer) ReadData(ctx *ReadContext, type_ reflect.Type, value r altValue := reflect.New(altType).Elem() // Read the value data - if err := serializer.ReadData(ctx, altType, altValue); err != nil { - return err + serializer.ReadData(ctx, altType, altValue) + if ctx.HasError() { + return } - // Set the union value - union := Union{Value: altValue.Interface()} - if value.Kind() == reflect.Ptr { - value.Elem().Set(reflect.ValueOf(union)) - } else { - value.Set(reflect.ValueOf(union)) + // Get the target struct value + targetValue := value + if targetValue.Kind() == reflect.Ptr { + if targetValue.IsNil() { + targetValue.Set(reflect.New(targetValue.Type().Elem())) + } + targetValue = targetValue.Elem() } - return nil + // Set the index field (1-based) + indexField := targetValue.FieldByName("Idx") + indexField.SetInt(int64(storedIndex + 1)) + + // Set the value pointer field + fieldName := fmt.Sprintf("V%d", storedIndex+1) + valueField := targetValue.FieldByName(fieldName) + // Create a pointer to the value and set it + ptrValue := reflect.New(altType) + ptrValue.Elem().Set(altValue) + valueField.Set(ptrValue) } -func (s *unionSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) error { - return s.Read(ctx, refMode, false, value) +func (s *unionSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, value) } -// RegisterUnionType registers a Union type with the specified alternative types. -// The alternative types are the types that the union can hold. +// ============================================================================ +// Registration Functions +// ============================================================================ + +// RegisterUnion2Type registers a Union2[T1, T2] type for serialization. // Returns an error if registration fails. // // Example: // // f := fory.NewFory() -// err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) +// err := fory.RegisterUnion2Type[int32, string](f) // if err != nil { // panic(err) // } -func (f *Fory) RegisterUnionType(alternativeTypes ...reflect.Type) error { - if len(alternativeTypes) == 0 { - return fmt.Errorf("union must have at least one alternative type") +func RegisterUnion2Type[T1 any, T2 any](f *Fory) error { + var zero1 T1 + var zero2 T2 + + unionType := reflect.TypeOf(Union2[T1, T2]{}) + alternativeTypes := []reflect.Type{ + reflect.TypeOf(zero1), + reflect.TypeOf(zero2), } - unionType := reflect.TypeOf(Union{}) - serializer := newUnionSerializer(f.typeResolver, alternativeTypes) + serializer := newUnionSerializer(unionType, f.typeResolver, alternativeTypes) // Register the union type with the serializer f.typeResolver.typeToSerializers[unionType] = serializer @@ -313,3 +473,53 @@ func (f *Fory) RegisterUnionType(alternativeTypes ...reflect.Type) error { return nil } + +// RegisterUnion3Type registers a Union3[T1, T2, T3] type for serialization. +func RegisterUnion3Type[T1 any, T2 any, T3 any](f *Fory) error { + var zero1 T1 + var zero2 T2 + var zero3 T3 + + unionType := reflect.TypeOf(Union3[T1, T2, T3]{}) + alternativeTypes := []reflect.Type{ + reflect.TypeOf(zero1), + reflect.TypeOf(zero2), + reflect.TypeOf(zero3), + } + + serializer := newUnionSerializer(unionType, f.typeResolver, alternativeTypes) + + f.typeResolver.typeToSerializers[unionType] = serializer + ptrUnionType := reflect.PtrTo(unionType) + f.typeResolver.typeToSerializers[ptrUnionType] = &ptrToValueSerializer{ + valueSerializer: serializer, + } + + return nil +} + +// RegisterUnion4Type registers a Union4[T1, T2, T3, T4] type for serialization. +func RegisterUnion4Type[T1 any, T2 any, T3 any, T4 any](f *Fory) error { + var zero1 T1 + var zero2 T2 + var zero3 T3 + var zero4 T4 + + unionType := reflect.TypeOf(Union4[T1, T2, T3, T4]{}) + alternativeTypes := []reflect.Type{ + reflect.TypeOf(zero1), + reflect.TypeOf(zero2), + reflect.TypeOf(zero3), + reflect.TypeOf(zero4), + } + + serializer := newUnionSerializer(unionType, f.typeResolver, alternativeTypes) + + f.typeResolver.typeToSerializers[unionType] = serializer + ptrUnionType := reflect.PtrTo(unionType) + f.typeResolver.typeToSerializers[ptrUnionType] = &ptrToValueSerializer{ + valueSerializer: serializer, + } + + return nil +} diff --git a/go/fory/union_test.go b/go/fory/union_test.go index a858f3626b..9a1b758f2b 100644 --- a/go/fory/union_test.go +++ b/go/fory/union_test.go @@ -18,210 +18,314 @@ package fory import ( - "reflect" "testing" "github.com/stretchr/testify/require" ) -func TestUnionBasicTypes(t *testing.T) { +// ============================================================================ +// Union2 Tests +// ============================================================================ + +func TestUnion2BasicTypes(t *testing.T) { f := NewFory() - err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + err := RegisterUnion2Type[int32, string](f) require.NoError(t, err) - // Test with int32 value - unionInt := Union{Value: int32(42)} - data, err := f.Serialize(unionInt) + // Test with int32 value (first alternative) + union1 := NewUnion2A[int32, string](42) + data, err := f.Serialize(union1) require.NoError(t, err) - var result Union - err = f.Deserialize(data, &result) + var result1 Union2[int32, string] + err = f.Deserialize(data, &result1) require.NoError(t, err) - require.Equal(t, int32(42), result.Value) + require.Equal(t, 1, result1.Index()) + require.True(t, result1.IsFirst()) + require.Equal(t, int32(42), result1.First()) - // Test with string value - unionStr := Union{Value: "hello"} - data, err = f.Serialize(unionStr) + // Test with string value (second alternative) + union2 := NewUnion2B[int32, string]("hello") + data, err = f.Serialize(union2) require.NoError(t, err) - err = f.Deserialize(data, &result) + var result2 Union2[int32, string] + err = f.Deserialize(data, &result2) require.NoError(t, err) - require.Equal(t, "hello", result.Value) + require.Equal(t, 2, result2.Index()) + require.True(t, result2.IsSecond()) + require.Equal(t, "hello", result2.Second()) } -func TestUnionMultipleTypes(t *testing.T) { - f := NewFory() - err := f.RegisterUnionType( - reflect.TypeOf(int32(0)), - reflect.TypeOf(""), - reflect.TypeOf(float64(0)), +func TestUnion2Match(t *testing.T) { + union1 := NewUnion2A[int32, string](42) + var matchedInt int32 + var matchedStr string + + union1.Match( + func(i int32) { matchedInt = i }, + func(s string) { matchedStr = s }, + ) + require.Equal(t, int32(42), matchedInt) + require.Empty(t, matchedStr) + + union2 := NewUnion2B[int32, string]("hello") + matchedInt = 0 + matchedStr = "" + + union2.Match( + func(i int32) { matchedInt = i }, + func(s string) { matchedStr = s }, ) + require.Equal(t, int32(0), matchedInt) + require.Equal(t, "hello", matchedStr) +} + +func TestUnion2WithFloats(t *testing.T) { + f := NewFory() + err := RegisterUnion2Type[float32, float64](f) require.NoError(t, err) - // Test with int32 - union1 := Union{Value: int32(123)} + // Test with float32 + union1 := NewUnion2A[float32, float64](float32(3.14)) data, err := f.Serialize(union1) require.NoError(t, err) - var result Union - err = f.Deserialize(data, &result) + var result1 Union2[float32, float64] + err = f.Deserialize(data, &result1) require.NoError(t, err) - require.Equal(t, int32(123), result.Value) + require.True(t, result1.IsFirst()) + require.InDelta(t, float32(3.14), result1.First(), 0.0001) - // Test with string - union2 := Union{Value: "test"} + // Test with float64 + union2 := NewUnion2B[float32, float64](float64(2.71828)) data, err = f.Serialize(union2) require.NoError(t, err) - err = f.Deserialize(data, &result) + var result2 Union2[float32, float64] + err = f.Deserialize(data, &result2) require.NoError(t, err) - require.Equal(t, "test", result.Value) + require.True(t, result2.IsSecond()) + require.InDelta(t, float64(2.71828), result2.Second(), 0.0001) +} - // Test with float64 - union3 := Union{Value: float64(3.14)} - data, err = f.Serialize(union3) +func TestUnion2WithBoolAndInt64(t *testing.T) { + f := NewFory() + err := RegisterUnion2Type[bool, int64](f) require.NoError(t, err) - err = f.Deserialize(data, &result) + // Test with bool + union1 := NewUnion2A[bool, int64](true) + data, err := f.Serialize(union1) require.NoError(t, err) - require.InDelta(t, 3.14, result.Value.(float64), 0.0001) -} -func TestUnionNullValue(t *testing.T) { - f := NewFory(WithTrackRef(true)) - err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + var result1 Union2[bool, int64] + err = f.Deserialize(data, &result1) require.NoError(t, err) + require.True(t, result1.IsFirst()) + require.True(t, result1.First()) - // Test with nil value - unionNil := Union{Value: nil} - data, err := f.Serialize(unionNil) + // Test with int64 + union2 := NewUnion2B[bool, int64](int64(9999999999)) + data, err = f.Serialize(union2) require.NoError(t, err) - var result Union - err = f.Deserialize(data, &result) + var result2 Union2[bool, int64] + err = f.Deserialize(data, &result2) require.NoError(t, err) - require.Nil(t, result.Value) + require.True(t, result2.IsSecond()) + require.Equal(t, int64(9999999999), result2.Second()) } -func TestUnionWithPointerValue(t *testing.T) { +// ============================================================================ +// Union3 Tests +// ============================================================================ + +func TestUnion3BasicTypes(t *testing.T) { f := NewFory() - err := f.RegisterUnionType(reflect.TypeOf((*int32)(nil)), reflect.TypeOf("")) + err := RegisterUnion3Type[int32, string, float64](f) require.NoError(t, err) - // Test with pointer to int32 - val := int32(42) - unionPtr := Union{Value: &val} - data, err := f.Serialize(unionPtr) + // Test with int32 value (first alternative) + union1 := NewUnion3A[int32, string, float64](123) + data, err := f.Serialize(union1) require.NoError(t, err) - var result Union - err = f.Deserialize(data, &result) + var result1 Union3[int32, string, float64] + err = f.Deserialize(data, &result1) require.NoError(t, err) + require.Equal(t, 1, result1.Index()) - resultPtr, ok := result.Value.(*int32) - require.True(t, ok) - require.Equal(t, int32(42), *resultPtr) -} - -func TestUnionNewHelper(t *testing.T) { - union := NewUnion(int32(42)) - require.Equal(t, int32(42), union.Value) - require.False(t, union.IsNil()) + // Test with string value (second alternative) + union2 := NewUnion3B[int32, string, float64]("test") + data, err = f.Serialize(union2) + require.NoError(t, err) - unionNil := NewUnion(nil) - require.True(t, unionNil.IsNil()) -} + var result2 Union3[int32, string, float64] + err = f.Deserialize(data, &result2) + require.NoError(t, err) + require.Equal(t, 2, result2.Index()) -func TestUnionInvalidAlternative(t *testing.T) { - f := NewFory() - err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + // Test with float64 value (third alternative) + union3 := NewUnion3C[int32, string, float64](3.14) + data, err = f.Serialize(union3) require.NoError(t, err) - // Try to serialize a union with an unregistered alternative type - unionBool := Union{Value: true} - _, err = f.Serialize(unionBool) - require.Error(t, err) - require.Contains(t, err.Error(), "doesn't match any alternative") + var result3 Union3[int32, string, float64] + err = f.Deserialize(data, &result3) + require.NoError(t, err) + require.Equal(t, 3, result3.Index()) } -func TestUnionEmptyRegistration(t *testing.T) { - f := NewFory() - err := f.RegisterUnionType() - require.Error(t, err) - require.Contains(t, err.Error(), "at least one alternative type") +func TestUnion3Match(t *testing.T) { + union := NewUnion3C[int32, string, float64](2.5) + + var matchedInt int32 + var matchedStr string + var matchedFloat float64 + + union.Match( + func(i int32) { matchedInt = i }, + func(s string) { matchedStr = s }, + func(f float64) { matchedFloat = f }, + ) + + require.Equal(t, int32(0), matchedInt) + require.Empty(t, matchedStr) + require.InDelta(t, 2.5, matchedFloat, 0.0001) } -func TestUnionWithBytes(t *testing.T) { +// ============================================================================ +// Union4 Tests +// ============================================================================ + +func TestUnion4BasicTypes(t *testing.T) { f := NewFory() - err := f.RegisterUnionType(reflect.TypeOf([]byte{}), reflect.TypeOf("")) + err := RegisterUnion4Type[int32, string, float64, bool](f) require.NoError(t, err) - // Test with bytes - unionBytes := Union{Value: []byte("hello")} - data, err := f.Serialize(unionBytes) + // Test with bool value (fourth alternative) + union := NewUnion4D[int32, string, float64, bool](true) + data, err := f.Serialize(union) require.NoError(t, err) - var result Union + var result Union4[int32, string, float64, bool] err = f.Deserialize(data, &result) require.NoError(t, err) - require.Equal(t, []byte("hello"), result.Value) + require.Equal(t, 4, result.Index()) +} - // Test with string - unionStr := Union{Value: "world"} - data, err = f.Serialize(unionStr) - require.NoError(t, err) +func TestUnion4Match(t *testing.T) { + union := NewUnion4B[int32, string, float64, bool]("world") - err = f.Deserialize(data, &result) - require.NoError(t, err) - require.Equal(t, "world", result.Value) + var matchedInt int32 + var matchedStr string + var matchedFloat float64 + var matchedBool bool + + union.Match( + func(i int32) { matchedInt = i }, + func(s string) { matchedStr = s }, + func(f float64) { matchedFloat = f }, + func(b bool) { matchedBool = b }, + ) + + require.Equal(t, int32(0), matchedInt) + require.Equal(t, "world", matchedStr) + require.Equal(t, float64(0), matchedFloat) + require.False(t, matchedBool) } -func TestUnionWithRefTracking(t *testing.T) { +// ============================================================================ +// Edge Cases +// ============================================================================ + +func TestUnion2WithRefTracking(t *testing.T) { f := NewFory(WithTrackRef(true)) - err := f.RegisterUnionType(reflect.TypeOf(int32(0)), reflect.TypeOf("")) + err := RegisterUnion2Type[int32, string](f) require.NoError(t, err) // Test with int32 value - unionInt := Union{Value: int32(42)} - data, err := f.Serialize(unionInt) + union := NewUnion2A[int32, string](42) + data, err := f.Serialize(union) require.NoError(t, err) - var result Union + var result Union2[int32, string] err = f.Deserialize(data, &result) require.NoError(t, err) - require.Equal(t, int32(42), result.Value) + require.Equal(t, int32(42), result.First()) +} + +func TestUnion2PanicOnWrongAccess(t *testing.T) { + union := NewUnion2A[int32, string](42) + + // Accessing First() should work + require.NotPanics(t, func() { + _ = union.First() + }) + + // Accessing Second() should panic + require.Panics(t, func() { + _ = union.Second() + }) +} + +func TestUnion2MultipleRegistrations(t *testing.T) { + f := NewFory() + + // Register first Union2 type + err := RegisterUnion2Type[int32, string](f) + require.NoError(t, err) - // Test with string value - unionStr := Union{Value: "hello"} - data, err = f.Serialize(unionStr) + // Register second different Union2 type + err = RegisterUnion2Type[bool, float64](f) require.NoError(t, err) - err = f.Deserialize(data, &result) + // Serialize and deserialize first type + union1 := NewUnion2A[int32, string](42) + data1, err := f.Serialize(union1) + require.NoError(t, err) + + var result1 Union2[int32, string] + err = f.Deserialize(data1, &result1) require.NoError(t, err) - require.Equal(t, "hello", result.Value) + require.Equal(t, int32(42), result1.First()) + + // Serialize and deserialize second type + union2 := NewUnion2A[bool, float64](true) + data2, err := f.Serialize(union2) + require.NoError(t, err) + + var result2 Union2[bool, float64] + err = f.Deserialize(data2, &result2) + require.NoError(t, err) + require.True(t, result2.First()) } -func TestUnionWithInt64AndBool(t *testing.T) { +func TestUnion2Bytes(t *testing.T) { f := NewFory() - err := f.RegisterUnionType(reflect.TypeOf(int64(0)), reflect.TypeOf(false)) + err := RegisterUnion2Type[[]byte, string](f) require.NoError(t, err) - // Test with int64 - union1 := Union{Value: int64(9999999999)} + // Test with bytes + union1 := NewUnion2A[[]byte, string]([]byte("hello bytes")) data, err := f.Serialize(union1) require.NoError(t, err) - var result Union - err = f.Deserialize(data, &result) + var result1 Union2[[]byte, string] + err = f.Deserialize(data, &result1) require.NoError(t, err) - require.Equal(t, int64(9999999999), result.Value) + require.True(t, result1.IsFirst()) + require.Equal(t, []byte("hello bytes"), result1.First()) - // Test with bool - union2 := Union{Value: true} + // Test with string + union2 := NewUnion2B[[]byte, string]("hello string") data, err = f.Serialize(union2) require.NoError(t, err) - err = f.Deserialize(data, &result) + var result2 Union2[[]byte, string] + err = f.Deserialize(data, &result2) require.NoError(t, err) - require.Equal(t, true, result.Value) + require.True(t, result2.IsSecond()) + require.Equal(t, "hello string", result2.Second()) }