diff --git a/tx_binary.go b/tx_binary.go index d4a3ba4..46b8dcd 100644 --- a/tx_binary.go +++ b/tx_binary.go @@ -1191,7 +1191,7 @@ func (enc *txBinaryEncoder) writeTxBinaryBody(tx *TxBinary, enumTable *txBinaryE func (enc *txBinaryEncoder) writePlatformEntries(entries []PlatformBinary, enumTable *txBinaryEnumTable) error { enc.writeUint32(uint32(len(entries))) for i, entry := range entries { - enumID, err := enumTable.platforms.id(entry.Platform) + enumID, err := enumTable.platforms.idOrFallback(entry.Platform, PlatformNone) if err != nil { return fmt.Errorf("platform[%d]: %w", i, err) } @@ -1204,7 +1204,7 @@ func (enc *txBinaryEncoder) writePlatformEntries(entries []PlatformBinary, enumT func (enc *txBinaryEncoder) writeMevAgentEntries(entries []MevAgentBinary, enumTable *txBinaryEnumTable) error { enc.writeUint32(uint32(len(entries))) for i, entry := range entries { - enumID, err := enumTable.mevAgents.id(entry.MevAgent) + enumID, err := enumTable.mevAgents.idOrFallback(entry.MevAgent, MevAgentUnknown) if err != nil { return fmt.Errorf("mev_agent[%d]: %w", i, err) } @@ -1592,7 +1592,7 @@ func txBinaryReadPlatformEntries(dec txBinaryBodyReader, enumTable *txBinaryEnum if err != nil { return nil, err } - platform, err := enumTable.platforms.value(enumID) + platform, err := enumTable.platforms.valueOrFallback(enumID, PlatformNone) if err != nil { return nil, fmt.Errorf("platform[%d]: %w", i, err) } @@ -1619,7 +1619,7 @@ func txBinaryReadMevAgentEntries(dec txBinaryBodyReader, enumTable *txBinaryEnum if err != nil { return nil, err } - mevAgent, err := enumTable.mevAgents.value(enumID) + mevAgent, err := enumTable.mevAgents.valueOrFallback(enumID, MevAgentUnknown) if err != nil { return nil, fmt.Errorf("mev_agent[%d]: %w", i, err) } @@ -2144,6 +2144,7 @@ var txBinaryEnumTables = map[uint16]*txBinaryEnumTable{ MevAgentAllenhark, MevAgentRaiden, MevAgentZan, + MevAgentTunneling, }, ), } @@ -2199,9 +2200,30 @@ func (set txBinaryEnumSet) id(value string) (uint16, error) { return id, nil } +func (set txBinaryEnumSet) idOrFallback(value string, fallback string) (uint16, error) { + if id, ok := set.ids[value]; ok { + return id, nil + } + id, ok := set.ids[fallback] + if !ok { + return 0, fmt.Errorf("unsupported %s fallback enum value %q for versioned tx binary", set.name, fallback) + } + return id, nil +} + func (set txBinaryEnumSet) value(id uint16) (string, error) { if int(id) >= len(set.values) { return "", fmt.Errorf("unknown %s enum id %d", set.name, id) } return set.values[id], nil } + +func (set txBinaryEnumSet) valueOrFallback(id uint16, fallback string) (string, error) { + if int(id) < len(set.values) { + return set.values[id], nil + } + if _, ok := set.ids[fallback]; !ok { + return "", fmt.Errorf("unsupported %s fallback enum value %q for versioned tx binary", set.name, fallback) + } + return fallback, nil +} diff --git a/tx_binary_test.go b/tx_binary_test.go index c1379fe..2b7ea4b 100644 --- a/tx_binary_test.go +++ b/tx_binary_test.go @@ -45,6 +45,10 @@ func TestTxBinaryRoundTrip(t *testing.T) { MevAgent: MevAgentZan, MevAgentFee: decimal.RequireFromString("0.040000000"), }, + MevAgentTunneling: { + MevAgent: MevAgentTunneling, + MevAgentFee: decimal.RequireFromString("0.050000000"), + }, }, Swaps: []Swap{ { @@ -153,6 +157,9 @@ func TestTxBinaryRoundTrip(t *testing.T) { if !decoded.MevAgent[MevAgentZan].MevAgentFee.Equal(original.MevAgent[MevAgentZan].MevAgentFee) { t.Fatalf("Zan MevAgent fee mismatch") } + if !decoded.MevAgent[MevAgentTunneling].MevAgentFee.Equal(original.MevAgent[MevAgentTunneling].MevAgentFee) { + t.Fatalf("Tunneling MevAgent fee mismatch") + } if len(decoded.Swaps) != 1 { t.Fatalf("Swaps len = %d, want 1", len(decoded.Swaps)) } @@ -232,6 +239,87 @@ func TestTxBinaryRejectsUnknownProgramEnum(t *testing.T) { } } +func TestTxBinaryLabelEnumsFallbackToUnknown(t *testing.T) { + original := &Tx{ + Signer: solana.WrappedSol, + Platform: map[string]platformInfo{ + "future-platform": { + Platform: "future-platform", + PlatformFee: decimal.RequireFromString("0.010000000"), + }, + }, + MevAgent: map[string]mevInfo{ + "future-mev-agent": { + MevAgent: "future-mev-agent", + MevAgentFee: decimal.RequireFromString("0.020000000"), + }, + }, + } + + encoded, err := EncodeTxBinary(original) + if err != nil { + t.Fatalf("EncodeTxBinary() error = %v", err) + } + decoded, err := DecodeTxBinary(encoded) + if err != nil { + t.Fatalf("DecodeTxBinary() error = %v", err) + } + + if len(decoded.Platform) != 1 { + t.Fatalf("Platform len = %d, want 1", len(decoded.Platform)) + } + if _, exists := decoded.Platform["future-platform"]; exists { + t.Fatalf("future platform was preserved, want fallback") + } + if !decoded.Platform[PlatformNone].PlatformFee.Equal(original.Platform["future-platform"].PlatformFee) { + t.Fatalf("PlatformNone fee = %s, want %s", decoded.Platform[PlatformNone].PlatformFee, original.Platform["future-platform"].PlatformFee) + } + + if len(decoded.MevAgent) != 1 { + t.Fatalf("MevAgent len = %d, want 1", len(decoded.MevAgent)) + } + if _, exists := decoded.MevAgent["future-mev-agent"]; exists { + t.Fatalf("future mev agent was preserved, want fallback") + } + if !decoded.MevAgent[MevAgentUnknown].MevAgentFee.Equal(original.MevAgent["future-mev-agent"].MevAgentFee) { + t.Fatalf("MevAgentUnknown fee = %s, want %s", decoded.MevAgent[MevAgentUnknown].MevAgentFee, original.MevAgent["future-mev-agent"].MevAgentFee) + } +} + +func TestTxBinaryReadLabelEnumUnknownIDsFallback(t *testing.T) { + enumTable := txBinaryEnumTables[txBinaryEnumVersionV1] + + platformFee := uint64(123) + platformEnc := txBinaryEncoder{} + platformEnc.writeUint32(1) + platformEnc.writeUint16(uint16(len(enumTable.platforms.values) + 10)) + platformEnc.writeUint64(platformFee) + + platformDec := txBinaryDecoder{reader: bytes.NewReader(platformEnc.bytes())} + platforms, err := txBinaryReadPlatformEntries(&platformDec, enumTable) + if err != nil { + t.Fatalf("txBinaryReadPlatformEntries() error = %v", err) + } + if len(platforms) != 1 || platforms[0].Platform != PlatformNone || platforms[0].PlatformFee != platformFee { + t.Fatalf("platform fallback = %+v, want %s/%d", platforms, PlatformNone, platformFee) + } + + mevFee := uint64(456) + mevEnc := txBinaryEncoder{} + mevEnc.writeUint32(1) + mevEnc.writeUint16(uint16(len(enumTable.mevAgents.values) + 10)) + mevEnc.writeUint64(mevFee) + + mevDec := txBinaryDecoder{reader: bytes.NewReader(mevEnc.bytes())} + mevAgents, err := txBinaryReadMevAgentEntries(&mevDec, enumTable) + if err != nil { + t.Fatalf("txBinaryReadMevAgentEntries() error = %v", err) + } + if len(mevAgents) != 1 || mevAgents[0].MevAgent != MevAgentUnknown || mevAgents[0].MevAgentFee != mevFee { + t.Fatalf("mev agent fallback = %+v, want %s/%d", mevAgents, MevAgentUnknown, mevFee) + } +} + func TestTxBinaryAcceptsKnownEventEnums(t *testing.T) { events := []string{ TxEventAddLP,