From d2879efcc6886b04adab436ac995b703d9e70472 Mon Sep 17 00:00:00 2001 From: thloyi Date: Thu, 16 Apr 2026 16:40:21 +0800 Subject: [PATCH] batch encode --- cmd/measure_tx_binary_block/main.go | 134 ++ enum.go | 12 +- tx_binary.go | 2100 +++++++++++++++++++++++++++ tx_binary_realdata_test.go | 146 ++ tx_binary_test.go | 627 ++++++++ 5 files changed, 3014 insertions(+), 5 deletions(-) create mode 100644 cmd/measure_tx_binary_block/main.go create mode 100644 tx_binary.go create mode 100644 tx_binary_realdata_test.go create mode 100644 tx_binary_test.go diff --git a/cmd/measure_tx_binary_block/main.go b/cmd/measure_tx_binary_block/main.go new file mode 100644 index 0000000..87920ad --- /dev/null +++ b/cmd/measure_tx_binary_block/main.go @@ -0,0 +1,134 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "os" + + pump_parser "github.com/thloyi/pump-parser" +) + +type blockResponse struct { + Result blockResult `json:"result"` +} + +type blockResult struct { + BlockTime *int64 `json:"blockTime"` + Transactions []pump_parser.RawTx `json:"transactions"` +} + +func main() { + var ( + filePath = flag.String("file", "", "path to getBlock payload json") + slot = flag.Uint64("slot", 0, "block slot") + swapsOnly = flag.Bool("swaps-only", false, "only include transactions with swaps > 0") + ) + flag.Parse() + + if *filePath == "" || *slot == 0 { + fmt.Fprintln(os.Stderr, "usage: measure_tx_binary_block -file /path/block.json -slot 413539056") + os.Exit(2) + } + + raw, err := os.ReadFile(*filePath) + if err != nil { + fmt.Fprintf(os.Stderr, "read file: %v\n", err) + os.Exit(1) + } + + var response blockResponse + if err := json.Unmarshal(raw, &response); err != nil { + fmt.Fprintf(os.Stderr, "unmarshal block payload: %v\n", err) + os.Exit(1) + } + + var blockTime *uint64 + if response.Result.BlockTime != nil { + bt := uint64(*response.Result.BlockTime) + blockTime = &bt + } + + total := len(response.Result.Transactions) + converted := 0 + parsed := 0 + convertFailures := 0 + parseFailures := 0 + encodeFailures := 0 + filteredOutNoSwaps := 0 + var totalRawTxBytes int + var totalSingleEncoded int + minSingleEncoded := -1 + maxSingleEncoded := 0 + + parsedTxs := make([]pump_parser.Tx, 0, total) + for i, rawTx := range response.Result.Transactions { + transactionJSON, err := json.Marshal(rawTx.Transaction) + if err == nil { + totalRawTxBytes += len(transactionJSON) + } + rawTx.BlockTime = 0 + if blockTime != nil { + rawTx.BlockTime = int64(*blockTime) + } + rawTx.Slot = *slot + rawTx.IndexWithinBlock = int64(i) + converted++ + + tx, err := pump_parser.ParseRawTx(&rawTx) + if err != nil { + parseFailures++ + continue + } + if *swapsOnly && len(tx.Swaps) == 0 { + filteredOutNoSwaps++ + continue + } + parsed++ + + encoded, err := pump_parser.EncodeTxBinary(tx) + if err != nil { + encodeFailures++ + continue + } + size := len(encoded) + totalSingleEncoded += size + if minSingleEncoded == -1 || size < minSingleEncoded { + minSingleEncoded = size + } + if size > maxSingleEncoded { + maxSingleEncoded = size + } + parsedTxs = append(parsedTxs, *tx) + } + + batchEncoded, err := pump_parser.EncodeTxsBinary(parsedTxs) + if err != nil { + fmt.Fprintf(os.Stderr, "encode txs binary: %v\n", err) + os.Exit(1) + } + + avgSingleEncoded := 0 + if parsed > 0 { + avgSingleEncoded = totalSingleEncoded / parsed + } + + fmt.Printf("block_slot=%d\n", *slot) + fmt.Printf("payload_json_bytes=%d\n", len(raw)) + fmt.Printf("transactions_total=%d\n", total) + fmt.Printf("transactions_converted=%d\n", converted) + fmt.Printf("transactions_parsed=%d\n", parsed) + fmt.Printf("transactions_filtered_no_swaps=%d\n", filteredOutNoSwaps) + fmt.Printf("convert_failures=%d\n", convertFailures) + fmt.Printf("parse_failures=%d\n", parseFailures) + fmt.Printf("encode_failures=%d\n", encodeFailures) + fmt.Printf("raw_tx_total_bytes=%d\n", totalRawTxBytes) + fmt.Printf("single_txbinary_total_bytes=%d\n", totalSingleEncoded) + fmt.Printf("single_txbinary_avg_bytes=%d\n", avgSingleEncoded) + fmt.Printf("single_txbinary_min_bytes=%d\n", minSingleEncoded) + fmt.Printf("single_txbinary_max_bytes=%d\n", maxSingleEncoded) + fmt.Printf("batch_shared_table_bytes=%d\n", len(batchEncoded)) + if totalSingleEncoded > 0 { + fmt.Printf("batch_vs_single_saved_bytes=%d\n", totalSingleEncoded-len(batchEncoded)) + } +} diff --git a/enum.go b/enum.go index a7ab2ef..cc3f089 100644 --- a/enum.go +++ b/enum.go @@ -119,9 +119,11 @@ func GetConditionByProgram(program string) []string { } const ( - TxEventAddLP = "add" - TxEventRemoveLP = "remove" - TxEventBuy = "buy" - TxEventSell = "sell" - TxEventBurn = "burn" + TxEventAddLP = "add" + TxEventRemoveLP = "remove" + TxEventBuy = "buy" + TxEventSell = "sell" + TxEventBuyFailed = "buy_failed" + TxEventSellFailed = "sell_failed" + TxEventBurn = "burn" ) diff --git a/tx_binary.go b/tx_binary.go new file mode 100644 index 0000000..c5b8c61 --- /dev/null +++ b/tx_binary.go @@ -0,0 +1,2100 @@ +package pump_parser + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "iter" + "math" + "sort" + "strconv" + + "github.com/gagliardetto/solana-go" + "github.com/shopspring/decimal" +) + +const ( + txBinarySchemaVersionCurrent uint16 = 3 + txBinaryEnumVersionV1 uint16 = 1 + + txBinarySOLScale int32 = 9 + txBinaryCUPriceScale int32 = 6 +) + +var txBinaryMagic = [4]byte{'P', 'T', 'X', 'B'} +var txsBinaryMagic = [4]byte{'P', 'T', 'X', 'S'} + +type TxBinary struct { + SchemaVersion uint16 + EnumVersion uint16 + AddressTable []solana.PublicKey + Signer uint32 + Block uint64 + BlockIndex uint64 + TxHash *[64]byte + CuFee uint64 + Swaps []SwapBinary + Platform []PlatformBinary + MevAgent []MevAgentBinary + CUPrice uint64 + BeforeSolBalance float64 + AfterSOLBalance float64 + ComputeUnitsConsumed uint64 + CuLimit uint32 +} + +type SwapBinary struct { + Program string + Event string + + TxIndex int32 + + InstrIdx uint8 + InnerIdx uint8 + + Pool uint32 + BaseMint uint32 + QuoteMint uint32 + BaseTokenProgram uint32 + QuoteTokenProgram uint32 + Creator uint32 + + BaseMintDecimals uint8 + QuoteMintDecimals uint8 + + User uint32 + BaseAmount uint64 + QuoteAmount uint64 + + SwapMode SwapMode + FixedAmount uint64 + FixedAmountSide SwapAmountSide + FixedMint uint32 + LimitAmountType SwapLimitType + LimitAmount uint64 + LimitAmountSide SwapAmountSide + LimitMint uint32 + ActualLimitAmount uint64 + ActualLimitAmountSide SwapAmountSide + SlippageBps uint64 + + BaseReserve uint64 + QuoteReserve uint64 + Mayhem bool + Cashback bool + + UserBaseBalance uint64 + UserQuoteBalance uint64 + EntryContract uint32 + + MigrateToPool uint32 + MigrateTopProgram uint32 + + LpMint uint32 + + AfterSOLBalance float64 +} + +type TxsBinary struct { + SchemaVersion uint16 + EnumVersion uint16 + AddressTable []solana.PublicKey + Txs []TxBinary +} + +type TxsBinaryReaderSource interface { + OpenTxsBinaryReader() (io.ReadCloser, error) +} + +type PlatformBinary struct { + Platform string + PlatformFee uint64 +} + +type MevAgentBinary struct { + MevAgent string + MevAgentFee uint64 +} + +type txBinaryBytesSource struct { + data []byte +} + +type txsBinaryMergePlan struct { + schemaVersion uint16 + enumVersion uint16 + enumTable *txBinaryEnumTable + addressTable []solana.PublicKey + addressIndex *txBinaryAddressIndex + txCount uint32 +} + +func (s txBinaryBytesSource) OpenTxsBinaryReader() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(s.data)), nil +} + +func NewTxBinary(tx *Tx) (*TxBinary, error) { + if tx == nil { + return nil, fmt.Errorf("tx is nil") + } + + addressTable, err := txBinaryBuildAddressTable([]*Tx{tx}) + if err != nil { + return nil, err + } + addressIndex, err := newTxBinaryAddressIndex(addressTable) + if err != nil { + return nil, err + } + return newTxBinaryWithAddressTable(tx, addressTable, addressIndex) +} + +func NewTxsBinary(txs []Tx) (*TxsBinary, error) { + txPtrs := make([]*Tx, 0, len(txs)) + for i := range txs { + txPtrs = append(txPtrs, &txs[i]) + } + addressTable, err := txBinaryBuildAddressTable(txPtrs) + if err != nil { + return nil, err + } + addressIndex, err := newTxBinaryAddressIndex(addressTable) + if err != nil { + return nil, err + } + + out := &TxsBinary{ + SchemaVersion: txBinarySchemaVersionCurrent, + EnumVersion: txBinaryEnumVersionV1, + AddressTable: addressTable, + Txs: make([]TxBinary, 0, len(txPtrs)), + } + for i, tx := range txPtrs { + binaryTx, err := newTxBinaryWithAddressTable(tx, addressTable, addressIndex) + if err != nil { + return nil, fmt.Errorf("tx[%d]: %w", i, err) + } + out.Txs = append(out.Txs, *binaryTx) + } + return out, nil +} + +func newTxBinaryWithAddressTable(tx *Tx, addressTable []solana.PublicKey, addressIndex *txBinaryAddressIndex) (*TxBinary, error) { + if tx == nil { + return nil, fmt.Errorf("tx is nil") + } + + out := &TxBinary{ + SchemaVersion: txBinarySchemaVersionCurrent, + EnumVersion: txBinaryEnumVersionV1, + Block: tx.Block, + BlockIndex: tx.BlockIndex, + CuLimit: tx.CuLimit, + ComputeUnitsConsumed: tx.ComputeUnitsConsumed, + } + if tx.TxHash != nil { + txHash := *tx.TxHash + out.TxHash = &txHash + } + + var err error + if out.CuFee, err = txBinaryDecimalToUint64(tx.CuFee, "tx.cu_fee"); err != nil { + return nil, err + } + if out.CUPrice, err = txBinaryScaledDecimalToUint64(tx.CUPrice, txBinaryCUPriceScale, "tx.cu_price"); err != nil { + return nil, err + } + if out.BeforeSolBalance, err = txBinaryDecimalToFloat64(tx.BeforeSolBalance, txBinarySOLScale, "tx.before_sol_balance"); err != nil { + return nil, err + } + if out.AfterSOLBalance, err = txBinaryDecimalToFloat64(tx.AfterSOLBalance, txBinarySOLScale, "tx.after_sol_balance"); err != nil { + return nil, err + } + + out.Platform, err = txBinaryPlatformsFromTx(tx.Platform) + if err != nil { + return nil, err + } + out.MevAgent, err = txBinaryMevAgentsFromTx(tx.MevAgent) + if err != nil { + return nil, err + } + out.AddressTable = addressTable + if out.Signer, err = addressIndex.id(tx.Signer); err != nil { + return nil, fmt.Errorf("tx.signer: %w", err) + } + + out.Swaps = make([]SwapBinary, 0, len(tx.Swaps)) + for i, swap := range tx.Swaps { + encodedSwap, err := newSwapBinary(swap, i, addressIndex) + if err != nil { + return nil, err + } + out.Swaps = append(out.Swaps, encodedSwap) + } + + return out, nil +} + +func EncodeTxBinary(tx *Tx) ([]byte, error) { + binaryTx, err := NewTxBinary(tx) + if err != nil { + return nil, err + } + return binaryTx.MarshalBinary() +} + +func EncodeTxsBinary(txs []Tx) ([]byte, error) { + binaryTxs, err := NewTxsBinary(txs) + if err != nil { + return nil, err + } + return binaryTxs.MarshalBinary() +} + +func DecodeTxBinary(data []byte) (*Tx, error) { + var binaryTx TxBinary + if err := binaryTx.UnmarshalBinary(data); err != nil { + return nil, err + } + return binaryTx.ToTx() +} + +func DecodeTxsBinary(data []byte) ([]*Tx, error) { + var binaryTxs TxsBinary + if err := binaryTxs.UnmarshalBinary(data); err != nil { + return nil, err + } + return binaryTxs.ToTxs() +} + +func DecodeTxsBinaryReader(r io.Reader) iter.Seq2[*Tx, error] { + return func(yield func(*Tx, error) bool) { + if r == nil { + yield(nil, fmt.Errorf("txs binary reader is nil")) + return + } + + dec := txBinaryStreamDecoder{reader: r} + header, err := dec.readTxsBinaryHeader() + if err != nil { + yield(nil, err) + return + } + + for i := uint32(0); i < header.count; i++ { + tx := TxBinary{ + SchemaVersion: header.schemaVersion, + EnumVersion: header.enumVersion, + AddressTable: header.addressTable, + } + if err := txBinaryReadTxBody(&dec, &tx, header.enumTable, header.addressTable); err != nil { + yield(nil, fmt.Errorf("tx[%d]: %w", i, err)) + return + } + + decodedTx, err := tx.ToTx() + if err != nil { + yield(nil, fmt.Errorf("tx[%d]: %w", i, err)) + return + } + if !yield(decodedTx, nil) { + return + } + } + } +} + +func MergeTxsBinaryBytes(encodedBatches [][]byte) ([]byte, error) { + sources := make([]TxsBinaryReaderSource, 0, len(encodedBatches)) + for _, encoded := range encodedBatches { + sources = append(sources, txBinaryBytesSource{data: encoded}) + } + + var out bytes.Buffer + if err := MergeTxsBinarySourcesToWriter(sources, &out); err != nil { + return nil, err + } + return out.Bytes(), nil +} + +func MergeTxsBinarySourcesToWriter(sources []TxsBinaryReaderSource, w io.Writer) error { + if w == nil { + return fmt.Errorf("txs binary writer is nil") + } + + plan, err := txBinaryBuildMergePlan(sources) + if err != nil { + return err + } + + headerBytes, err := txBinaryMarshalTxsHeader(plan.schemaVersion, plan.enumVersion, plan.addressTable, plan.txCount) + if err != nil { + return err + } + if err := txBinaryWriteAll(w, headerBytes); err != nil { + return err + } + + for sourceIndex, source := range sources { + reader, err := source.OpenTxsBinaryReader() + if err != nil { + return fmt.Errorf("source[%d]: open reader: %w", sourceIndex, err) + } + + dec := txBinaryStreamDecoder{reader: reader} + batchIndex := 0 + for { + header, err := dec.readTxsBinaryHeaderOrEOF() + if err != nil { + closeErr := reader.Close() + if err == io.EOF { + if closeErr != nil { + return fmt.Errorf("source[%d]: close reader: %w", sourceIndex, closeErr) + } + break + } + return fmt.Errorf("source[%d].batch[%d]: %w", sourceIndex, batchIndex, err) + } + + for txIndex := uint32(0); txIndex < header.count; txIndex++ { + tx := TxBinary{ + SchemaVersion: header.schemaVersion, + EnumVersion: header.enumVersion, + AddressTable: header.addressTable, + } + if err := txBinaryReadTxBody(&dec, &tx, header.enumTable, header.addressTable); err != nil { + reader.Close() + return fmt.Errorf("source[%d].batch[%d].tx[%d]: %w", sourceIndex, batchIndex, txIndex, err) + } + if err := txBinaryRemapTxAddressTable(&tx, header.addressTable, plan.addressTable, plan.addressIndex); err != nil { + reader.Close() + return fmt.Errorf("source[%d].batch[%d].tx[%d]: %w", sourceIndex, batchIndex, txIndex, err) + } + + bodyBytes, err := txBinaryMarshalTxBody(&tx, plan.enumTable) + if err != nil { + reader.Close() + return fmt.Errorf("source[%d].batch[%d].tx[%d]: %w", sourceIndex, batchIndex, txIndex, err) + } + if err := txBinaryWriteAll(w, bodyBytes); err != nil { + reader.Close() + return fmt.Errorf("source[%d].batch[%d].tx[%d]: write merged body: %w", sourceIndex, batchIndex, txIndex, err) + } + } + batchIndex++ + } + } + + return nil +} + +func (tx *TxBinary) MarshalBinary() ([]byte, error) { + if tx == nil { + return nil, fmt.Errorf("tx binary is nil") + } + if tx.SchemaVersion != txBinarySchemaVersionCurrent { + return nil, fmt.Errorf("unsupported tx binary schema version: %d", tx.SchemaVersion) + } + + enumTable, err := txBinaryEnumTableByVersion(tx.EnumVersion) + if err != nil { + return nil, err + } + addressTable := tx.AddressTable + + enc := txBinaryEncoder{} + enc.writeBytes(txBinaryMagic[:]) + enc.writeUint16(tx.SchemaVersion) + enc.writeUint16(tx.EnumVersion) + if err := enc.writeAddressTable(addressTable); err != nil { + return nil, err + } + if err := enc.writeTxBinaryBody(tx, enumTable); err != nil { + return nil, err + } + + return enc.bytes(), nil +} + +func (txs *TxsBinary) MarshalBinary() ([]byte, error) { + if txs == nil { + return nil, fmt.Errorf("txs binary is nil") + } + if txs.SchemaVersion != txBinarySchemaVersionCurrent { + return nil, fmt.Errorf("unsupported tx binary schema version: %d", txs.SchemaVersion) + } + + enumTable, err := txBinaryEnumTableByVersion(txs.EnumVersion) + if err != nil { + return nil, err + } + + enc := txBinaryEncoder{} + enc.writeBytes(txsBinaryMagic[:]) + enc.writeUint16(txs.SchemaVersion) + enc.writeUint16(txs.EnumVersion) + if err := enc.writeAddressTable(txs.AddressTable); err != nil { + return nil, err + } + enc.writeUint32(uint32(len(txs.Txs))) + for i := range txs.Txs { + if err := enc.writeTxBinaryBody(&txs.Txs[i], enumTable); err != nil { + return nil, fmt.Errorf("tx[%d]: %w", i, err) + } + } + return enc.bytes(), nil +} + +func txBinaryMarshalTxsHeader(schemaVersion uint16, enumVersion uint16, addressTable []solana.PublicKey, txCount uint32) ([]byte, error) { + enc := txBinaryEncoder{} + enc.writeBytes(txsBinaryMagic[:]) + enc.writeUint16(schemaVersion) + enc.writeUint16(enumVersion) + if err := enc.writeAddressTable(addressTable); err != nil { + return nil, err + } + enc.writeUint32(txCount) + return enc.bytes(), nil +} + +func txBinaryMarshalTxBody(tx *TxBinary, enumTable *txBinaryEnumTable) ([]byte, error) { + enc := txBinaryEncoder{} + if err := enc.writeTxBinaryBody(tx, enumTable); err != nil { + return nil, err + } + return enc.bytes(), nil +} + +func (tx *TxBinary) UnmarshalBinary(data []byte) error { + dec := txBinaryDecoder{reader: bytes.NewReader(data)} + + magic, err := dec.readN(len(txBinaryMagic)) + if err != nil { + return err + } + if !bytes.Equal(magic, txBinaryMagic[:]) { + return fmt.Errorf("invalid tx binary magic") + } + + tx.SchemaVersion, err = dec.readUint16() + if err != nil { + return err + } + if tx.SchemaVersion != txBinarySchemaVersionCurrent { + return fmt.Errorf("unsupported tx binary schema version: %d", tx.SchemaVersion) + } + + tx.EnumVersion, err = dec.readUint16() + if err != nil { + return err + } + enumTable, err := txBinaryEnumTableByVersion(tx.EnumVersion) + if err != nil { + return err + } + tx.AddressTable, err = dec.readAddressTable() + if err != nil { + return err + } + if err := txBinaryReadTxBody(&dec, tx, enumTable, tx.AddressTable); err != nil { + return err + } + if dec.reader.Len() != 0 { + return fmt.Errorf("unexpected trailing tx binary data: %d bytes", dec.reader.Len()) + } + return nil +} + +func (txs *TxsBinary) UnmarshalBinary(data []byte) error { + dec := txBinaryDecoder{reader: bytes.NewReader(data)} + + magic, err := dec.readN(len(txsBinaryMagic)) + if err != nil { + return err + } + if !bytes.Equal(magic, txsBinaryMagic[:]) { + return fmt.Errorf("invalid txs binary magic") + } + + txs.SchemaVersion, err = dec.readUint16() + if err != nil { + return err + } + if txs.SchemaVersion != txBinarySchemaVersionCurrent { + return fmt.Errorf("unsupported tx binary schema version: %d", txs.SchemaVersion) + } + + txs.EnumVersion, err = dec.readUint16() + if err != nil { + return err + } + enumTable, err := txBinaryEnumTableByVersion(txs.EnumVersion) + if err != nil { + return err + } + txs.AddressTable, err = dec.readAddressTable() + if err != nil { + return err + } + + count, err := dec.readUint32() + if err != nil { + return err + } + txs.Txs = make([]TxBinary, 0, count) + for i := uint32(0); i < count; i++ { + tx := TxBinary{ + SchemaVersion: txs.SchemaVersion, + EnumVersion: txs.EnumVersion, + AddressTable: txs.AddressTable, + } + if err := txBinaryReadTxBody(&dec, &tx, enumTable, txs.AddressTable); err != nil { + return fmt.Errorf("tx[%d]: %w", i, err) + } + txs.Txs = append(txs.Txs, tx) + } + if dec.reader.Len() != 0 { + return fmt.Errorf("unexpected trailing txs binary data: %d bytes", dec.reader.Len()) + } + return nil +} + +func (tx *TxBinary) ToTx() (*Tx, error) { + if tx == nil { + return nil, nil + } + + signer, err := txBinaryAddressAt(tx.AddressTable, tx.Signer, "tx.signer") + if err != nil { + return nil, err + } + + out := &Tx{ + Signer: signer, + Block: tx.Block, + BlockIndex: tx.BlockIndex, + CuFee: decimal.NewFromUint64(tx.CuFee), + CUPrice: decimal.NewFromUint64(tx.CUPrice).Shift(-txBinaryCUPriceScale), + BeforeSolBalance: txBinaryFloat64ToDecimal(tx.BeforeSolBalance, txBinarySOLScale), + AfterSOLBalance: txBinaryFloat64ToDecimal(tx.AfterSOLBalance, txBinarySOLScale), + ComputeUnitsConsumed: tx.ComputeUnitsConsumed, + CuLimit: tx.CuLimit, + } + if tx.TxHash != nil { + txHash := *tx.TxHash + out.TxHash = &txHash + } + + if len(tx.Platform) > 0 { + out.Platform = make(map[string]platformInfo, len(tx.Platform)) + for _, platform := range tx.Platform { + out.Platform[platform.Platform] = platformInfo{ + Platform: platform.Platform, + PlatformFee: decimal.NewFromUint64(platform.PlatformFee).Shift(-txBinarySOLScale), + } + } + } + if len(tx.MevAgent) > 0 { + out.MevAgent = make(map[string]mevInfo, len(tx.MevAgent)) + for _, mevAgent := range tx.MevAgent { + out.MevAgent[mevAgent.MevAgent] = mevInfo{ + MevAgent: mevAgent.MevAgent, + MevAgentFee: decimal.NewFromUint64(mevAgent.MevAgentFee).Shift(-txBinarySOLScale), + } + } + } + if len(tx.Swaps) > 0 { + out.Swaps = make([]Swap, 0, len(tx.Swaps)) + for i, swap := range tx.Swaps { + decodedSwap, err := swap.toSwap(tx.AddressTable, i) + if err != nil { + return nil, err + } + out.Swaps = append(out.Swaps, decodedSwap) + } + } + + return out, nil +} + +func (txs *TxsBinary) ToTxs() ([]*Tx, error) { + if txs == nil { + return nil, nil + } + out := make([]*Tx, 0, len(txs.Txs)) + for i := range txs.Txs { + txs.Txs[i].AddressTable = txs.AddressTable + tx, err := txs.Txs[i].ToTx() + if err != nil { + return nil, fmt.Errorf("tx[%d]: %w", i, err) + } + out = append(out, tx) + } + return out, nil +} + +func newSwapBinary(swap Swap, index int, addressIndex *txBinaryAddressIndex) (SwapBinary, error) { + pool, err := addressIndex.id(swap.Pool) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].pool: %w", index, err) + } + baseMint, err := addressIndex.id(swap.BaseMint) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].base_mint: %w", index, err) + } + quoteMint, err := addressIndex.id(swap.QuoteMint) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].quote_mint: %w", index, err) + } + baseTokenProgram, err := addressIndex.id(swap.BaseTokenProgram) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].base_token_program: %w", index, err) + } + quoteTokenProgram, err := addressIndex.id(swap.QuoteTokenProgram) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].quote_token_program: %w", index, err) + } + creator, err := addressIndex.id(swap.Creator) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].creator: %w", index, err) + } + user, err := addressIndex.id(swap.User) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].user: %w", index, err) + } + fixedMint, err := addressIndex.id(swap.FixedMint) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].fixed_mint: %w", index, err) + } + limitMint, err := addressIndex.id(swap.LimitMint) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].limit_mint: %w", index, err) + } + entryContract, err := addressIndex.id(swap.EntryContract) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].entry_contract: %w", index, err) + } + migrateToPool, err := addressIndex.id(swap.MigrateToPool) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].migrate_to_pool: %w", index, err) + } + migrateTopProgram, err := addressIndex.id(swap.MigrateTopProgram) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].migrate_top_program: %w", index, err) + } + lpMint, err := addressIndex.id(swap.LpMint) + if err != nil { + return SwapBinary{}, fmt.Errorf("swap[%d].lp_mint: %w", index, err) + } + + out := SwapBinary{ + Program: swap.Program, + Event: swap.Event, + TxIndex: int32(swap.TxIndex), + InstrIdx: swap.InstrIdx, + InnerIdx: swap.InnerIdx, + Pool: pool, + BaseMint: baseMint, + QuoteMint: quoteMint, + BaseTokenProgram: baseTokenProgram, + QuoteTokenProgram: quoteTokenProgram, + Creator: creator, + BaseMintDecimals: swap.BaseMintDecimals, + QuoteMintDecimals: swap.QuoteMintDecimals, + User: user, + SwapMode: swap.SwapMode, + FixedAmountSide: swap.FixedAmountSide, + FixedMint: fixedMint, + LimitAmountType: swap.LimitAmountType, + LimitAmountSide: swap.LimitAmountSide, + LimitMint: limitMint, + ActualLimitAmountSide: swap.ActualLimitAmountSide, + Mayhem: swap.Mayhem, + Cashback: swap.Cashback, + EntryContract: entryContract, + MigrateToPool: migrateToPool, + MigrateTopProgram: migrateTopProgram, + LpMint: lpMint, + } + if swap.TxIndex > math.MaxInt32 || swap.TxIndex < math.MinInt32 { + return SwapBinary{}, fmt.Errorf("swap[%d].tx_index overflows int32: %d", index, swap.TxIndex) + } + + if out.BaseAmount, err = txBinaryDecimalToUint64(swap.BaseAmount, fmt.Sprintf("swap[%d].base_amount", index)); err != nil { + return SwapBinary{}, err + } + if out.QuoteAmount, err = txBinaryDecimalToUint64(swap.QuoteAmount, fmt.Sprintf("swap[%d].quote_amount", index)); err != nil { + return SwapBinary{}, err + } + if out.FixedAmount, err = txBinaryDecimalToUint64(swap.FixedAmount, fmt.Sprintf("swap[%d].fixed_amount", index)); err != nil { + return SwapBinary{}, err + } + if out.LimitAmount, err = txBinaryDecimalToUint64(swap.LimitAmount, fmt.Sprintf("swap[%d].limit_amount", index)); err != nil { + return SwapBinary{}, err + } + if out.ActualLimitAmount, err = txBinaryDecimalToUint64(swap.ActualLimitAmount, fmt.Sprintf("swap[%d].actual_limit_amount", index)); err != nil { + return SwapBinary{}, err + } + if out.SlippageBps, err = txBinaryRoundedDecimalToUint64(swap.SlippageBps, fmt.Sprintf("swap[%d].slippage_bps", index)); err != nil { + return SwapBinary{}, err + } + if out.BaseReserve, err = txBinaryDecimalToUint64(swap.BaseReserve, fmt.Sprintf("swap[%d].base_reserve", index)); err != nil { + return SwapBinary{}, err + } + if out.QuoteReserve, err = txBinaryDecimalToUint64(swap.QuoteReserve, fmt.Sprintf("swap[%d].quote_reserve", index)); err != nil { + return SwapBinary{}, err + } + if out.UserBaseBalance, err = txBinaryDecimalToUint64(swap.UserBaseBalance, fmt.Sprintf("swap[%d].user_base_balance", index)); err != nil { + return SwapBinary{}, err + } + if out.UserQuoteBalance, err = txBinaryDecimalToUint64(swap.UserQuoteBalance, fmt.Sprintf("swap[%d].user_quote_balance", index)); err != nil { + return SwapBinary{}, err + } + if out.AfterSOLBalance, err = txBinaryDecimalToFloat64(swap.AfterSOLBalance, txBinarySOLScale, fmt.Sprintf("swap[%d].after_sol_balance", index)); err != nil { + return SwapBinary{}, err + } + + return out, nil +} + +func (swap SwapBinary) toSwap(addressTable []solana.PublicKey, index int) (Swap, error) { + pool, err := txBinaryAddressAt(addressTable, swap.Pool, fmt.Sprintf("swap[%d].pool", index)) + if err != nil { + return Swap{}, err + } + baseMint, err := txBinaryAddressAt(addressTable, swap.BaseMint, fmt.Sprintf("swap[%d].base_mint", index)) + if err != nil { + return Swap{}, err + } + quoteMint, err := txBinaryAddressAt(addressTable, swap.QuoteMint, fmt.Sprintf("swap[%d].quote_mint", index)) + if err != nil { + return Swap{}, err + } + baseTokenProgram, err := txBinaryAddressAt(addressTable, swap.BaseTokenProgram, fmt.Sprintf("swap[%d].base_token_program", index)) + if err != nil { + return Swap{}, err + } + quoteTokenProgram, err := txBinaryAddressAt(addressTable, swap.QuoteTokenProgram, fmt.Sprintf("swap[%d].quote_token_program", index)) + if err != nil { + return Swap{}, err + } + creator, err := txBinaryAddressAt(addressTable, swap.Creator, fmt.Sprintf("swap[%d].creator", index)) + if err != nil { + return Swap{}, err + } + user, err := txBinaryAddressAt(addressTable, swap.User, fmt.Sprintf("swap[%d].user", index)) + if err != nil { + return Swap{}, err + } + fixedMint, err := txBinaryAddressAt(addressTable, swap.FixedMint, fmt.Sprintf("swap[%d].fixed_mint", index)) + if err != nil { + return Swap{}, err + } + limitMint, err := txBinaryAddressAt(addressTable, swap.LimitMint, fmt.Sprintf("swap[%d].limit_mint", index)) + if err != nil { + return Swap{}, err + } + entryContract, err := txBinaryAddressAt(addressTable, swap.EntryContract, fmt.Sprintf("swap[%d].entry_contract", index)) + if err != nil { + return Swap{}, err + } + migrateToPool, err := txBinaryAddressAt(addressTable, swap.MigrateToPool, fmt.Sprintf("swap[%d].migrate_to_pool", index)) + if err != nil { + return Swap{}, err + } + migrateTopProgram, err := txBinaryAddressAt(addressTable, swap.MigrateTopProgram, fmt.Sprintf("swap[%d].migrate_top_program", index)) + if err != nil { + return Swap{}, err + } + lpMint, err := txBinaryAddressAt(addressTable, swap.LpMint, fmt.Sprintf("swap[%d].lp_mint", index)) + if err != nil { + return Swap{}, err + } + + return Swap{ + Program: swap.Program, + Event: swap.Event, + TxIndex: int(swap.TxIndex), + InstrIdx: swap.InstrIdx, + InnerIdx: swap.InnerIdx, + Pool: pool, + BaseMint: baseMint, + QuoteMint: quoteMint, + BaseTokenProgram: baseTokenProgram, + QuoteTokenProgram: quoteTokenProgram, + Creator: creator, + BaseMintDecimals: swap.BaseMintDecimals, + QuoteMintDecimals: swap.QuoteMintDecimals, + User: user, + BaseAmount: decimal.NewFromUint64(swap.BaseAmount), + QuoteAmount: decimal.NewFromUint64(swap.QuoteAmount), + SwapMode: swap.SwapMode, + FixedAmount: decimal.NewFromUint64(swap.FixedAmount), + FixedAmountSide: swap.FixedAmountSide, + FixedMint: fixedMint, + LimitAmountType: swap.LimitAmountType, + LimitAmount: decimal.NewFromUint64(swap.LimitAmount), + LimitAmountSide: swap.LimitAmountSide, + LimitMint: limitMint, + ActualLimitAmount: decimal.NewFromUint64(swap.ActualLimitAmount), + ActualLimitAmountSide: swap.ActualLimitAmountSide, + SlippageBps: decimal.NewFromUint64(swap.SlippageBps), + BaseReserve: decimal.NewFromUint64(swap.BaseReserve), + QuoteReserve: decimal.NewFromUint64(swap.QuoteReserve), + Mayhem: swap.Mayhem, + Cashback: swap.Cashback, + UserBaseBalance: decimal.NewFromUint64(swap.UserBaseBalance), + UserQuoteBalance: decimal.NewFromUint64(swap.UserQuoteBalance), + EntryContract: entryContract, + MigrateToPool: migrateToPool, + MigrateTopProgram: migrateTopProgram, + LpMint: lpMint, + AfterSOLBalance: txBinaryFloat64ToDecimal(swap.AfterSOLBalance, txBinarySOLScale), + }, nil +} + +func txBinaryPlatformsFromTx(platforms map[string]platformInfo) ([]PlatformBinary, error) { + if len(platforms) == 0 { + return nil, nil + } + + keys := make([]string, 0, len(platforms)) + for key := range platforms { + keys = append(keys, key) + } + sort.Strings(keys) + + out := make([]PlatformBinary, 0, len(keys)) + for _, key := range keys { + platform := platforms[key] + platformFee, err := txBinaryScaledDecimalToUint64(platform.PlatformFee, txBinarySOLScale, fmt.Sprintf("platform[%s].fee", key)) + if err != nil { + return nil, err + } + out = append(out, PlatformBinary{ + Platform: key, + PlatformFee: platformFee, + }) + } + return out, nil +} + +func txBinaryMevAgentsFromTx(mevAgents map[string]mevInfo) ([]MevAgentBinary, error) { + if len(mevAgents) == 0 { + return nil, nil + } + + keys := make([]string, 0, len(mevAgents)) + for key := range mevAgents { + keys = append(keys, key) + } + sort.Strings(keys) + + out := make([]MevAgentBinary, 0, len(keys)) + for _, key := range keys { + mevAgent := mevAgents[key] + mevFee, err := txBinaryScaledDecimalToUint64(mevAgent.MevAgentFee, txBinarySOLScale, fmt.Sprintf("mev_agent[%s].fee", key)) + if err != nil { + return nil, err + } + out = append(out, MevAgentBinary{ + MevAgent: key, + MevAgentFee: mevFee, + }) + } + return out, nil +} + +func txBinaryBuildAddressTable(txs []*Tx) ([]solana.PublicKey, error) { + builder := txBinaryAddressTableBuilder{ + index: make(map[solana.PublicKey]struct{}), + } + + for txIndex, tx := range txs { + if tx == nil { + return nil, fmt.Errorf("tx[%d] is nil", txIndex) + } + if err := builder.add(tx.Signer); err != nil { + return nil, fmt.Errorf("tx[%d].signer: %w", txIndex, err) + } + for swapIndex, swap := range tx.Swaps { + for _, address := range []solana.PublicKey{ + swap.Pool, + swap.BaseMint, + swap.QuoteMint, + swap.BaseTokenProgram, + swap.QuoteTokenProgram, + swap.Creator, + swap.User, + swap.FixedMint, + swap.LimitMint, + swap.EntryContract, + swap.MigrateToPool, + swap.MigrateTopProgram, + swap.LpMint, + } { + if err := builder.add(address); err != nil { + return nil, fmt.Errorf("tx[%d].swap[%d] address table: %w", txIndex, swapIndex, err) + } + } + } + } + return builder.addresses, nil +} + +type txBinaryAddressTableBuilder struct { + addresses []solana.PublicKey + index map[solana.PublicKey]struct{} +} + +func (b *txBinaryAddressTableBuilder) add(address solana.PublicKey) error { + if _, ok := b.index[address]; ok { + return nil + } + if uint64(len(b.addresses)) >= uint64(math.MaxUint32) { + return fmt.Errorf("address table exceeds uint32 capacity") + } + b.addresses = append(b.addresses, address) + b.index[address] = struct{}{} + return nil +} + +type txBinaryAddressIndex struct { + index map[solana.PublicKey]uint32 +} + +func newTxBinaryAddressIndex(addresses []solana.PublicKey) (*txBinaryAddressIndex, error) { + if uint64(len(addresses)) > uint64(math.MaxUint32) { + return nil, fmt.Errorf("address table exceeds uint32 capacity") + } + index := make(map[solana.PublicKey]uint32, len(addresses)) + for i, address := range addresses { + if _, exists := index[address]; exists { + return nil, fmt.Errorf("duplicate address table entry: %s", address.String()) + } + index[address] = uint32(i) + } + return &txBinaryAddressIndex{index: index}, nil +} + +func (idx *txBinaryAddressIndex) id(address solana.PublicKey) (uint32, error) { + id, ok := idx.index[address] + if !ok { + return 0, fmt.Errorf("address not found in address table: %s", address.String()) + } + return id, nil +} + +func txBinaryAddressAt(addressTable []solana.PublicKey, index uint32, field string) (solana.PublicKey, error) { + if int(index) >= len(addressTable) { + return solana.PublicKey{}, fmt.Errorf("%s address index out of range: %d", field, index) + } + return addressTable[index], nil +} + +func txBinaryDecimalToUint64(value decimal.Decimal, field string) (uint64, error) { + if value.IsNegative() { + return 0, fmt.Errorf("%s must be >= 0, got %s", field, value.String()) + } + if !value.Equal(value.Truncate(0)) { + return 0, fmt.Errorf("%s must be an integer, got %s", field, value.String()) + } + + bigInt := value.BigInt() + if !bigInt.IsUint64() { + return 0, fmt.Errorf("%s overflows uint64: %s", field, value.String()) + } + return bigInt.Uint64(), nil +} + +func txBinaryScaledDecimalToUint64(value decimal.Decimal, scale int32, field string) (uint64, error) { + return txBinaryDecimalToUint64(value.Shift(scale), field) +} + +func txBinaryRoundedDecimalToUint64(value decimal.Decimal, field string) (uint64, error) { + return txBinaryDecimalToUint64(value.Round(0), field) +} + +func txBinaryDecimalToFloat64(value decimal.Decimal, scale int32, field string) (float64, error) { + rounded := value.Round(scale) + f, exact := rounded.Float64() + if !exact && math.IsInf(f, 0) { + return 0, fmt.Errorf("%s cannot be represented as float64: %s", field, value.String()) + } + return f, nil +} + +func txBinaryFloat64ToDecimal(value float64, scale int32) decimal.Decimal { + formatted := strconv.FormatFloat(value, 'f', int(scale), 64) + out, err := decimal.NewFromString(formatted) + if err != nil { + return decimal.Zero + } + return out +} + +type txBinaryEncoder struct { + buf bytes.Buffer +} + +func (enc *txBinaryEncoder) bytes() []byte { + return enc.buf.Bytes() +} + +func (enc *txBinaryEncoder) writeBool(value bool) { + if value { + enc.writeUint8(1) + return + } + enc.writeUint8(0) +} + +func (enc *txBinaryEncoder) writeUint8(value uint8) { + enc.buf.WriteByte(value) +} + +func (enc *txBinaryEncoder) writeUint16(value uint16) { + var raw [2]byte + binary.LittleEndian.PutUint16(raw[:], value) + enc.buf.Write(raw[:]) +} + +func (enc *txBinaryEncoder) writeUint32(value uint32) { + var raw [4]byte + binary.LittleEndian.PutUint32(raw[:], value) + enc.buf.Write(raw[:]) +} + +func (enc *txBinaryEncoder) writeUint64(value uint64) { + var raw [8]byte + binary.LittleEndian.PutUint64(raw[:], value) + enc.buf.Write(raw[:]) +} + +func (enc *txBinaryEncoder) writeFloat64(value float64) { + enc.writeUint64(math.Float64bits(value)) +} + +func (enc *txBinaryEncoder) writeInt32(value int32) { + enc.writeUint32(uint32(value)) +} + +func (enc *txBinaryEncoder) writeBytes(value []byte) { + enc.buf.Write(value) +} + +func (enc *txBinaryEncoder) writeAddressTable(addresses []solana.PublicKey) error { + if uint64(len(addresses)) > uint64(math.MaxUint32) { + return fmt.Errorf("address table exceeds uint32 capacity") + } + enc.writeUint32(uint32(len(addresses))) + for _, address := range addresses { + enc.writeBytes(address[:]) + } + return nil +} + +func (enc *txBinaryEncoder) writeTxBinaryBody(tx *TxBinary, enumTable *txBinaryEnumTable) error { + enc.writeUint32(tx.Signer) + enc.writeUint64(tx.Block) + enc.writeUint64(tx.BlockIndex) + enc.writeBool(tx.TxHash != nil) + if tx.TxHash != nil { + enc.writeBytes(tx.TxHash[:]) + } + enc.writeUint64(tx.CuFee) + enc.writeUint64(tx.CUPrice) + enc.writeFloat64(tx.BeforeSolBalance) + enc.writeFloat64(tx.AfterSOLBalance) + enc.writeUint64(tx.ComputeUnitsConsumed) + enc.writeUint32(tx.CuLimit) + if err := enc.writePlatformEntries(tx.Platform, enumTable); err != nil { + return err + } + if err := enc.writeMevAgentEntries(tx.MevAgent, enumTable); err != nil { + return err + } + if err := enc.writeSwaps(tx.Swaps, enumTable); err != nil { + return err + } + return nil +} + +func (enc *txBinaryEncoder) writePlatformEntries(entries []PlatformBinary, enumTable *txBinaryEnumTable) error { + enc.writeUint32(uint32(len(entries))) + for i, entry := range entries { + enumID, err := enumTable.platforms.id(entry.Platform) + if err != nil { + return fmt.Errorf("platform[%d]: %w", i, err) + } + enc.writeUint16(enumID) + enc.writeUint64(entry.PlatformFee) + } + return nil +} + +func (enc *txBinaryEncoder) writeMevAgentEntries(entries []MevAgentBinary, enumTable *txBinaryEnumTable) error { + enc.writeUint32(uint32(len(entries))) + for i, entry := range entries { + enumID, err := enumTable.mevAgents.id(entry.MevAgent) + if err != nil { + return fmt.Errorf("mev_agent[%d]: %w", i, err) + } + enc.writeUint16(enumID) + enc.writeUint64(entry.MevAgentFee) + } + return nil +} + +func (enc *txBinaryEncoder) writeSwaps(swaps []SwapBinary, enumTable *txBinaryEnumTable) error { + enc.writeUint32(uint32(len(swaps))) + for i, swap := range swaps { + programID, err := enumTable.programs.id(swap.Program) + if err != nil { + return fmt.Errorf("swap[%d].program: %w", i, err) + } + eventID, err := enumTable.events.id(swap.Event) + if err != nil { + return fmt.Errorf("swap[%d].event: %w", i, err) + } + + enc.writeUint16(programID) + enc.writeUint16(eventID) + enc.writeInt32(swap.TxIndex) + enc.writeUint8(swap.InstrIdx) + enc.writeUint8(swap.InnerIdx) + enc.writeUint32(swap.Pool) + enc.writeUint32(swap.BaseMint) + enc.writeUint32(swap.QuoteMint) + enc.writeUint32(swap.BaseTokenProgram) + enc.writeUint32(swap.QuoteTokenProgram) + enc.writeUint32(swap.Creator) + enc.writeUint8(swap.BaseMintDecimals) + enc.writeUint8(swap.QuoteMintDecimals) + enc.writeUint32(swap.User) + enc.writeUint64(swap.BaseAmount) + enc.writeUint64(swap.QuoteAmount) + enc.writeUint8(uint8(swap.SwapMode)) + enc.writeUint64(swap.FixedAmount) + enc.writeUint8(uint8(swap.FixedAmountSide)) + enc.writeUint32(swap.FixedMint) + enc.writeUint8(uint8(swap.LimitAmountType)) + enc.writeUint64(swap.LimitAmount) + enc.writeUint8(uint8(swap.LimitAmountSide)) + enc.writeUint32(swap.LimitMint) + enc.writeUint64(swap.ActualLimitAmount) + enc.writeUint8(uint8(swap.ActualLimitAmountSide)) + enc.writeUint64(swap.SlippageBps) + enc.writeUint64(swap.BaseReserve) + enc.writeUint64(swap.QuoteReserve) + enc.writeBool(swap.Mayhem) + enc.writeBool(swap.Cashback) + enc.writeUint64(swap.UserBaseBalance) + enc.writeUint64(swap.UserQuoteBalance) + enc.writeUint32(swap.EntryContract) + enc.writeUint32(swap.MigrateToPool) + enc.writeUint32(swap.MigrateTopProgram) + enc.writeUint32(swap.LpMint) + enc.writeFloat64(swap.AfterSOLBalance) + } + return nil +} + +type txBinaryDecoder struct { + reader *bytes.Reader +} + +type txBinaryStreamDecoder struct { + reader io.Reader +} + +type txBinaryBodyReader interface { + readBool() (bool, error) + readUint8() (uint8, error) + readUint16() (uint16, error) + readUint32() (uint32, error) + readUint64() (uint64, error) + readFloat64() (float64, error) + readInt32() (int32, error) + readN(int) ([]byte, error) +} + +type txsBinaryHeader struct { + schemaVersion uint16 + enumVersion uint16 + addressTable []solana.PublicKey + enumTable *txBinaryEnumTable + count uint32 +} + +func (dec *txBinaryDecoder) readBool() (bool, error) { + value, err := dec.readUint8() + if err != nil { + return false, err + } + switch value { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, fmt.Errorf("invalid bool value: %d", value) + } +} + +func (dec *txBinaryDecoder) readUint8() (uint8, error) { + raw, err := dec.readN(1) + if err != nil { + return 0, err + } + return raw[0], nil +} + +func (dec *txBinaryDecoder) readUint16() (uint16, error) { + raw, err := dec.readN(2) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint16(raw), nil +} + +func (dec *txBinaryDecoder) readUint32() (uint32, error) { + raw, err := dec.readN(4) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint32(raw), nil +} + +func (dec *txBinaryDecoder) readUint64() (uint64, error) { + raw, err := dec.readN(8) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint64(raw), nil +} + +func (dec *txBinaryDecoder) readFloat64() (float64, error) { + value, err := dec.readUint64() + if err != nil { + return 0, err + } + return math.Float64frombits(value), nil +} + +func (dec *txBinaryDecoder) readInt32() (int32, error) { + value, err := dec.readUint32() + if err != nil { + return 0, err + } + return int32(value), nil +} + +func (dec *txBinaryDecoder) readAddressTable() ([]solana.PublicKey, error) { + return txBinaryReadAddressTable(dec) +} + +func (dec *txBinaryDecoder) readN(n int) ([]byte, error) { + out := make([]byte, n) + if _, err := io.ReadFull(dec.reader, out); err != nil { + return nil, err + } + return out, nil +} + +func (dec *txBinaryDecoder) readPlatformEntries(enumTable *txBinaryEnumTable) ([]PlatformBinary, error) { + return txBinaryReadPlatformEntries(dec, enumTable) +} + +func (dec *txBinaryDecoder) readMevAgentEntries(enumTable *txBinaryEnumTable) ([]MevAgentBinary, error) { + return txBinaryReadMevAgentEntries(dec, enumTable) +} + +func (dec *txBinaryDecoder) readSwaps(enumTable *txBinaryEnumTable, _ []solana.PublicKey) ([]SwapBinary, error) { + return txBinaryReadSwaps(dec, enumTable) +} + +func (dec *txBinaryDecoder) readTxBinaryBody(tx *TxBinary, enumTable *txBinaryEnumTable, addressTable []solana.PublicKey) error { + return txBinaryReadTxBody(dec, tx, enumTable, addressTable) +} + +func (dec *txBinaryStreamDecoder) readBool() (bool, error) { + value, err := dec.readUint8() + if err != nil { + return false, err + } + switch value { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, fmt.Errorf("invalid bool value: %d", value) + } +} + +func (dec *txBinaryStreamDecoder) readUint8() (uint8, error) { + raw, err := dec.readN(1) + if err != nil { + return 0, err + } + return raw[0], nil +} + +func (dec *txBinaryStreamDecoder) readUint16() (uint16, error) { + raw, err := dec.readN(2) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint16(raw), nil +} + +func (dec *txBinaryStreamDecoder) readUint32() (uint32, error) { + raw, err := dec.readN(4) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint32(raw), nil +} + +func (dec *txBinaryStreamDecoder) readUint64() (uint64, error) { + raw, err := dec.readN(8) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint64(raw), nil +} + +func (dec *txBinaryStreamDecoder) readFloat64() (float64, error) { + value, err := dec.readUint64() + if err != nil { + return 0, err + } + return math.Float64frombits(value), nil +} + +func (dec *txBinaryStreamDecoder) readInt32() (int32, error) { + value, err := dec.readUint32() + if err != nil { + return 0, err + } + return int32(value), nil +} + +func (dec *txBinaryStreamDecoder) readAddressTable() ([]solana.PublicKey, error) { + return txBinaryReadAddressTable(dec) +} + +func (dec *txBinaryStreamDecoder) readN(n int) ([]byte, error) { + out := make([]byte, n) + if _, err := io.ReadFull(dec.reader, out); err != nil { + return nil, err + } + return out, nil +} + +func (dec *txBinaryStreamDecoder) readTxsBinaryHeader() (*txsBinaryHeader, error) { + magic, err := dec.readN(len(txsBinaryMagic)) + if err != nil { + return nil, err + } + if !bytes.Equal(magic, txsBinaryMagic[:]) { + return nil, fmt.Errorf("invalid txs binary magic") + } + + schemaVersion, err := dec.readUint16() + if err != nil { + return nil, err + } + if schemaVersion != txBinarySchemaVersionCurrent { + return nil, fmt.Errorf("unsupported tx binary schema version: %d", schemaVersion) + } + + enumVersion, err := dec.readUint16() + if err != nil { + return nil, err + } + enumTable, err := txBinaryEnumTableByVersion(enumVersion) + if err != nil { + return nil, err + } + + addressTable, err := dec.readAddressTable() + if err != nil { + return nil, err + } + + count, err := dec.readUint32() + if err != nil { + return nil, err + } + + return &txsBinaryHeader{ + schemaVersion: schemaVersion, + enumVersion: enumVersion, + addressTable: addressTable, + enumTable: enumTable, + count: count, + }, nil +} + +func (dec *txBinaryStreamDecoder) readNOrEOF(n int) ([]byte, error) { + out := make([]byte, n) + readN, err := io.ReadFull(dec.reader, out) + if err != nil { + if err == io.EOF && readN == 0 { + return nil, io.EOF + } + return nil, err + } + return out, nil +} + +func (dec *txBinaryStreamDecoder) readTxsBinaryHeaderOrEOF() (*txsBinaryHeader, error) { + magic, err := dec.readNOrEOF(len(txsBinaryMagic)) + if err != nil { + return nil, err + } + if !bytes.Equal(magic, txsBinaryMagic[:]) { + return nil, fmt.Errorf("invalid txs binary magic") + } + + schemaVersion, err := dec.readUint16() + if err != nil { + return nil, err + } + if schemaVersion != txBinarySchemaVersionCurrent { + return nil, fmt.Errorf("unsupported tx binary schema version: %d", schemaVersion) + } + + enumVersion, err := dec.readUint16() + if err != nil { + return nil, err + } + enumTable, err := txBinaryEnumTableByVersion(enumVersion) + if err != nil { + return nil, err + } + + addressTable, err := dec.readAddressTable() + if err != nil { + return nil, err + } + + count, err := dec.readUint32() + if err != nil { + return nil, err + } + + return &txsBinaryHeader{ + schemaVersion: schemaVersion, + enumVersion: enumVersion, + addressTable: addressTable, + enumTable: enumTable, + count: count, + }, nil +} + +func txBinaryReadAddressTable(dec txBinaryBodyReader) ([]solana.PublicKey, error) { + count, err := dec.readUint32() + if err != nil { + return nil, err + } + addresses := make([]solana.PublicKey, 0, count) + for i := uint32(0); i < count; i++ { + raw, err := dec.readN(solana.PublicKeyLength) + if err != nil { + return nil, err + } + var publicKey solana.PublicKey + copy(publicKey[:], raw) + addresses = append(addresses, publicKey) + } + return addresses, nil +} + +func txBinaryReadPlatformEntries(dec txBinaryBodyReader, enumTable *txBinaryEnumTable) ([]PlatformBinary, error) { + count, err := dec.readUint32() + if err != nil { + return nil, err + } + out := make([]PlatformBinary, 0, count) + for i := uint32(0); i < count; i++ { + enumID, err := dec.readUint16() + if err != nil { + return nil, err + } + platform, err := enumTable.platforms.value(enumID) + if err != nil { + return nil, fmt.Errorf("platform[%d]: %w", i, err) + } + fee, err := dec.readUint64() + if err != nil { + return nil, err + } + out = append(out, PlatformBinary{ + Platform: platform, + PlatformFee: fee, + }) + } + return out, nil +} + +func txBinaryReadMevAgentEntries(dec txBinaryBodyReader, enumTable *txBinaryEnumTable) ([]MevAgentBinary, error) { + count, err := dec.readUint32() + if err != nil { + return nil, err + } + out := make([]MevAgentBinary, 0, count) + for i := uint32(0); i < count; i++ { + enumID, err := dec.readUint16() + if err != nil { + return nil, err + } + mevAgent, err := enumTable.mevAgents.value(enumID) + if err != nil { + return nil, fmt.Errorf("mev_agent[%d]: %w", i, err) + } + fee, err := dec.readUint64() + if err != nil { + return nil, err + } + out = append(out, MevAgentBinary{ + MevAgent: mevAgent, + MevAgentFee: fee, + }) + } + return out, nil +} + +func txBinaryReadSwaps(dec txBinaryBodyReader, enumTable *txBinaryEnumTable) ([]SwapBinary, error) { + count, err := dec.readUint32() + if err != nil { + return nil, err + } + out := make([]SwapBinary, 0, count) + for i := uint32(0); i < count; i++ { + programID, err := dec.readUint16() + if err != nil { + return nil, err + } + program, err := enumTable.programs.value(programID) + if err != nil { + return nil, fmt.Errorf("swap[%d].program: %w", i, err) + } + eventID, err := dec.readUint16() + if err != nil { + return nil, err + } + event, err := enumTable.events.value(eventID) + if err != nil { + return nil, fmt.Errorf("swap[%d].event: %w", i, err) + } + + swap := SwapBinary{ + Program: program, + Event: event, + } + + if swap.TxIndex, err = dec.readInt32(); err != nil { + return nil, err + } + if swap.InstrIdx, err = dec.readUint8(); err != nil { + return nil, err + } + if swap.InnerIdx, err = dec.readUint8(); err != nil { + return nil, err + } + if swap.Pool, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.BaseMint, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.QuoteMint, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.BaseTokenProgram, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.QuoteTokenProgram, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.Creator, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.BaseMintDecimals, err = dec.readUint8(); err != nil { + return nil, err + } + if swap.QuoteMintDecimals, err = dec.readUint8(); err != nil { + return nil, err + } + if swap.User, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.BaseAmount, err = dec.readUint64(); err != nil { + return nil, err + } + if swap.QuoteAmount, err = dec.readUint64(); err != nil { + return nil, err + } + swapMode, err := dec.readUint8() + if err != nil { + return nil, err + } + swap.SwapMode = SwapMode(swapMode) + if swap.FixedAmount, err = dec.readUint64(); err != nil { + return nil, err + } + fixedAmountSide, err := dec.readUint8() + if err != nil { + return nil, err + } + swap.FixedAmountSide = SwapAmountSide(fixedAmountSide) + if swap.FixedMint, err = dec.readUint32(); err != nil { + return nil, err + } + limitType, err := dec.readUint8() + if err != nil { + return nil, err + } + swap.LimitAmountType = SwapLimitType(limitType) + if swap.LimitAmount, err = dec.readUint64(); err != nil { + return nil, err + } + limitAmountSide, err := dec.readUint8() + if err != nil { + return nil, err + } + swap.LimitAmountSide = SwapAmountSide(limitAmountSide) + if swap.LimitMint, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.ActualLimitAmount, err = dec.readUint64(); err != nil { + return nil, err + } + actualLimitAmountSide, err := dec.readUint8() + if err != nil { + return nil, err + } + swap.ActualLimitAmountSide = SwapAmountSide(actualLimitAmountSide) + if swap.SlippageBps, err = dec.readUint64(); err != nil { + return nil, err + } + if swap.BaseReserve, err = dec.readUint64(); err != nil { + return nil, err + } + if swap.QuoteReserve, err = dec.readUint64(); err != nil { + return nil, err + } + if swap.Mayhem, err = dec.readBool(); err != nil { + return nil, err + } + if swap.Cashback, err = dec.readBool(); err != nil { + return nil, err + } + if swap.UserBaseBalance, err = dec.readUint64(); err != nil { + return nil, err + } + if swap.UserQuoteBalance, err = dec.readUint64(); err != nil { + return nil, err + } + if swap.EntryContract, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.MigrateToPool, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.MigrateTopProgram, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.LpMint, err = dec.readUint32(); err != nil { + return nil, err + } + if swap.AfterSOLBalance, err = dec.readFloat64(); err != nil { + return nil, err + } + out = append(out, swap) + } + return out, nil +} + +func txBinaryReadTxBody(dec txBinaryBodyReader, tx *TxBinary, enumTable *txBinaryEnumTable, addressTable []solana.PublicKey) error { + var err error + tx.AddressTable = addressTable + if tx.Signer, err = dec.readUint32(); err != nil { + return err + } + if tx.Block, err = dec.readUint64(); err != nil { + return err + } + if tx.BlockIndex, err = dec.readUint64(); err != nil { + return err + } + + hasTxHash, err := dec.readBool() + if err != nil { + return err + } + if hasTxHash { + rawHash, err := dec.readN(64) + if err != nil { + return err + } + var txHash [64]byte + copy(txHash[:], rawHash) + tx.TxHash = &txHash + } else { + tx.TxHash = nil + } + + if tx.CuFee, err = dec.readUint64(); err != nil { + return err + } + if tx.CUPrice, err = dec.readUint64(); err != nil { + return err + } + if tx.BeforeSolBalance, err = dec.readFloat64(); err != nil { + return err + } + if tx.AfterSOLBalance, err = dec.readFloat64(); err != nil { + return err + } + if tx.ComputeUnitsConsumed, err = dec.readUint64(); err != nil { + return err + } + if tx.CuLimit, err = dec.readUint32(); err != nil { + return err + } + if tx.Platform, err = txBinaryReadPlatformEntries(dec, enumTable); err != nil { + return err + } + if tx.MevAgent, err = txBinaryReadMevAgentEntries(dec, enumTable); err != nil { + return err + } + if tx.Swaps, err = txBinaryReadSwaps(dec, enumTable); err != nil { + return err + } + return nil +} + +func txBinaryBuildMergePlan(sources []TxsBinaryReaderSource) (*txsBinaryMergePlan, error) { + if len(sources) == 0 { + return nil, fmt.Errorf("txs binary sources are empty") + } + + builder := txBinaryAddressTableBuilder{ + index: make(map[solana.PublicKey]struct{}), + } + plan := &txsBinaryMergePlan{} + hasBatch := false + + for sourceIndex, source := range sources { + if source == nil { + return nil, fmt.Errorf("source[%d] is nil", sourceIndex) + } + + reader, err := source.OpenTxsBinaryReader() + if err != nil { + return nil, fmt.Errorf("source[%d]: open reader: %w", sourceIndex, err) + } + + dec := txBinaryStreamDecoder{reader: reader} + batchIndex := 0 + for { + header, err := dec.readTxsBinaryHeaderOrEOF() + if err != nil { + closeErr := reader.Close() + if err == io.EOF { + if closeErr != nil { + return nil, fmt.Errorf("source[%d]: close reader: %w", sourceIndex, closeErr) + } + break + } + return nil, fmt.Errorf("source[%d].batch[%d]: %w", sourceIndex, batchIndex, err) + } + + if !hasBatch { + plan.schemaVersion = header.schemaVersion + plan.enumVersion = header.enumVersion + plan.enumTable = header.enumTable + hasBatch = true + } else { + if header.schemaVersion != plan.schemaVersion { + reader.Close() + return nil, fmt.Errorf("source[%d].batch[%d]: schema version mismatch: got %d want %d", sourceIndex, batchIndex, header.schemaVersion, plan.schemaVersion) + } + if header.enumVersion != plan.enumVersion { + reader.Close() + return nil, fmt.Errorf("source[%d].batch[%d]: enum version mismatch: got %d want %d", sourceIndex, batchIndex, header.enumVersion, plan.enumVersion) + } + } + + for addressIndex, address := range header.addressTable { + if err := builder.add(address); err != nil { + reader.Close() + return nil, fmt.Errorf("source[%d].batch[%d].address[%d]: %w", sourceIndex, batchIndex, addressIndex, err) + } + } + + if uint64(plan.txCount)+uint64(header.count) > uint64(math.MaxUint32) { + reader.Close() + return nil, fmt.Errorf("merged tx count exceeds uint32 capacity") + } + plan.txCount += header.count + + for txIndex := uint32(0); txIndex < header.count; txIndex++ { + tx := TxBinary{ + SchemaVersion: header.schemaVersion, + EnumVersion: header.enumVersion, + AddressTable: header.addressTable, + } + if err := txBinaryReadTxBody(&dec, &tx, header.enumTable, header.addressTable); err != nil { + reader.Close() + return nil, fmt.Errorf("source[%d].batch[%d].tx[%d]: %w", sourceIndex, batchIndex, txIndex, err) + } + } + batchIndex++ + } + } + + if !hasBatch { + return nil, fmt.Errorf("no txs binary batches found") + } + + addressIndex, err := newTxBinaryAddressIndex(builder.addresses) + if err != nil { + return nil, err + } + plan.addressTable = builder.addresses + plan.addressIndex = addressIndex + return plan, nil +} + +func txBinaryRemapTxAddressTable(tx *TxBinary, fromAddressTable []solana.PublicKey, toAddressTable []solana.PublicKey, toAddressIndex *txBinaryAddressIndex) error { + var err error + if tx.Signer, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Signer, "tx.signer"); err != nil { + return err + } + + for i := range tx.Swaps { + if tx.Swaps[i].Pool, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].Pool, fmt.Sprintf("swap[%d].pool", i)); err != nil { + return err + } + if tx.Swaps[i].BaseMint, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].BaseMint, fmt.Sprintf("swap[%d].base_mint", i)); err != nil { + return err + } + if tx.Swaps[i].QuoteMint, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].QuoteMint, fmt.Sprintf("swap[%d].quote_mint", i)); err != nil { + return err + } + if tx.Swaps[i].BaseTokenProgram, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].BaseTokenProgram, fmt.Sprintf("swap[%d].base_token_program", i)); err != nil { + return err + } + if tx.Swaps[i].QuoteTokenProgram, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].QuoteTokenProgram, fmt.Sprintf("swap[%d].quote_token_program", i)); err != nil { + return err + } + if tx.Swaps[i].Creator, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].Creator, fmt.Sprintf("swap[%d].creator", i)); err != nil { + return err + } + if tx.Swaps[i].User, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].User, fmt.Sprintf("swap[%d].user", i)); err != nil { + return err + } + if tx.Swaps[i].FixedMint, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].FixedMint, fmt.Sprintf("swap[%d].fixed_mint", i)); err != nil { + return err + } + if tx.Swaps[i].LimitMint, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].LimitMint, fmt.Sprintf("swap[%d].limit_mint", i)); err != nil { + return err + } + if tx.Swaps[i].EntryContract, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].EntryContract, fmt.Sprintf("swap[%d].entry_contract", i)); err != nil { + return err + } + if tx.Swaps[i].MigrateToPool, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].MigrateToPool, fmt.Sprintf("swap[%d].migrate_to_pool", i)); err != nil { + return err + } + if tx.Swaps[i].MigrateTopProgram, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].MigrateTopProgram, fmt.Sprintf("swap[%d].migrate_top_program", i)); err != nil { + return err + } + if tx.Swaps[i].LpMint, err = txBinaryRemapAddressRef(fromAddressTable, toAddressIndex, tx.Swaps[i].LpMint, fmt.Sprintf("swap[%d].lp_mint", i)); err != nil { + return err + } + } + + tx.AddressTable = toAddressTable + return nil +} + +func txBinaryRemapAddressRef(fromAddressTable []solana.PublicKey, toAddressIndex *txBinaryAddressIndex, fromRef uint32, field string) (uint32, error) { + address, err := txBinaryAddressAt(fromAddressTable, fromRef, field) + if err != nil { + return 0, err + } + return toAddressIndex.id(address) +} + +func txBinaryWriteAll(w io.Writer, data []byte) error { + for len(data) > 0 { + written, err := w.Write(data) + if err != nil { + return err + } + if written == 0 { + return io.ErrShortWrite + } + data = data[written:] + } + return nil +} + +type txBinaryEnumTable struct { + version uint16 + programs txBinaryEnumSet + events txBinaryEnumSet + platforms txBinaryEnumSet + mevAgents txBinaryEnumSet +} + +type txBinaryEnumSet struct { + name string + values []string + ids map[string]uint16 +} + +var txBinaryEnumTables = map[uint16]*txBinaryEnumTable{ + txBinaryEnumVersionV1: newTxBinaryEnumTable( + txBinaryEnumVersionV1, + "program", + []string{ + "", + SolProgramPump, + SolProgramRaydiumV4, + SolProgramRaydiumCLMM, + SolProgramRaydiumCPMM, + SolProgramMeteoraDLMM, + SolProgramOrcaWhirPool, + SolProgramPumpAMM, + SolProgramMeteoraAmmV2, + SolProgramMeteoraBondingCurve, + SolProgramMeteoraPools, + SolProgramRaydiumLaunchLab, + SolProgramRaydiumLaunchLabBonk, + }, + "event", + []string{ + "", + TxEventAddLP, + TxEventRemoveLP, + TxEventBuy, + TxEventSell, + TxEventBuyFailed, + TxEventSellFailed, + TxEventBurn, + }, + "platform", + []string{ + "", + PlatformGMGN, + PlatformPhoton, + PlatformAxiom, + PlatformPepe, + PlatformBullX, + PlatformBanana, + PlatformTrojan, + PlatformRaybot, + PlatformMoonshot, + PlatformMEVX, + PlatformTradeWiz, + PlatformSolTradingBot, + PlatformMoonshotMoney, + PlatformMaestro, + PlatformBonkBot, + PlatformPadre, + PlatformDexScreener, + PlatformFake, + PlatformNone, + }, + "mev_agent", + []string{ + "", + MevAgentJito, + MevAgent0slot, + MevAgentBlocxRoute, + MevAgentNozomi, + MevAgentNextBlock, + MevAgentHelius, + MevAgentNode1, + MevAgentFlashBlock, + MevAgentUnknown, + MevAgentBlockRazor, + MevAgentFast, + MevAgentSoyas, + MevAgentStellium, + MevAgentAstralane, + MevagentFa1con, + MevagentBlocksprint, + MevAgentMoon, + MevAgentSpeedlanding, + MevAgentAllenhark, + MevAgentRaiden, + }, + ), +} + +func txBinaryEnumTableByVersion(version uint16) (*txBinaryEnumTable, error) { + table, ok := txBinaryEnumTables[version] + if !ok { + return nil, fmt.Errorf("unsupported tx binary enum version: %d", version) + } + return table, nil +} + +func newTxBinaryEnumTable( + version uint16, + programName string, + programs []string, + eventName string, + events []string, + platformName string, + platforms []string, + mevAgentName string, + mevAgents []string, +) *txBinaryEnumTable { + return &txBinaryEnumTable{ + version: version, + programs: newTxBinaryEnumSet(programName, programs), + events: newTxBinaryEnumSet(eventName, events), + platforms: newTxBinaryEnumSet(platformName, platforms), + mevAgents: newTxBinaryEnumSet(mevAgentName, mevAgents), + } +} + +func newTxBinaryEnumSet(name string, values []string) txBinaryEnumSet { + ids := make(map[string]uint16, len(values)) + for i, value := range values { + if _, exists := ids[value]; exists { + panic(fmt.Sprintf("duplicate %s enum value: %q", name, value)) + } + ids[value] = uint16(i) + } + return txBinaryEnumSet{ + name: name, + values: values, + ids: ids, + } +} + +func (set txBinaryEnumSet) id(value string) (uint16, error) { + id, ok := set.ids[value] + if !ok { + return 0, fmt.Errorf("unsupported %s enum value %q for versioned tx binary", set.name, value) + } + return id, nil +} + +func (set txBinaryEnumSet) value(id uint16) (string, error) { + if int(id) >= len(set.values) { + return "", fmt.Errorf("unknown %s enum id %d", set.name, id) + } + return set.values[id], nil +} diff --git a/tx_binary_realdata_test.go b/tx_binary_realdata_test.go new file mode 100644 index 0000000..dfa8f79 --- /dev/null +++ b/tx_binary_realdata_test.go @@ -0,0 +1,146 @@ +package pump_parser + +import ( + "encoding/json" + "os" + "path/filepath" + "sort" + "strings" + "testing" + + "github.com/gagliardetto/solana-go/rpc" +) + +func TestTxBinaryRealFixtureSizes(t *testing.T) { + fixtures, err := filepath.Glob(filepath.Join("testdata", "rpc", "*.json")) + if err != nil { + t.Fatalf("glob fixtures: %v", err) + } + if len(fixtures) == 0 { + t.Fatal("no rpc fixtures found") + } + sort.Strings(fixtures) + + type sizeResult struct { + name string + swaps int + platforms int + mevAgents int + addresses int + encodedBytes int + fixtureBytes int + txBinaryBytes int + } + + results := make([]sizeResult, 0, len(fixtures)) + totalEncoded := 0 + + for _, fixture := range fixtures { + tx, rawTxBytesLen, fixtureBytesLen := mustParseRPCFixtureTxForBinarySize(t, fixture) + binaryTx, err := NewTxBinary(tx) + if err != nil { + t.Fatalf("build tx binary fixture %s: %v", fixture, err) + } + encoded, err := binaryTx.MarshalBinary() + if err != nil { + t.Fatalf("encode fixture %s: %v", fixture, err) + } + + result := sizeResult{ + name: strings.TrimSuffix(filepath.Base(fixture), filepath.Ext(fixture)), + swaps: len(tx.Swaps), + platforms: len(tx.Platform), + mevAgents: len(tx.MevAgent), + addresses: len(binaryTx.AddressTable), + encodedBytes: len(encoded), + fixtureBytes: fixtureBytesLen, + txBinaryBytes: rawTxBytesLen, + } + results = append(results, result) + totalEncoded += result.encodedBytes + } + + for _, result := range results { + t.Logf( + "%s encoded=%dB swaps=%d platforms=%d mev=%d addresses=%d fixture_json=%dB raw_tx=%dB", + result.name, + result.encodedBytes, + result.swaps, + result.platforms, + result.mevAgents, + result.addresses, + result.fixtureBytes, + result.txBinaryBytes, + ) + } + + minResult := results[0] + maxResult := results[0] + for _, result := range results[1:] { + if result.encodedBytes < minResult.encodedBytes { + minResult = result + } + if result.encodedBytes > maxResult.encodedBytes { + maxResult = result + } + } + + t.Logf( + "summary fixtures=%d avg=%dB min=%dB(%s) max=%dB(%s)", + len(results), + totalEncoded/len(results), + minResult.encodedBytes, + minResult.name, + maxResult.encodedBytes, + maxResult.name, + ) +} + +func mustParseRPCFixtureTxForBinarySize(t *testing.T, fixturePath string) (*Tx, int, int) { + t.Helper() + + raw, err := os.ReadFile(fixturePath) + if err != nil { + t.Fatalf("read fixture %s: %v", fixturePath, err) + } + + var response struct { + Result *rpc.GetTransactionResult `json:"result"` + } + if err := json.Unmarshal(raw, &response); err != nil { + t.Fatalf("unmarshal fixture %s: %v", fixturePath, err) + } + if response.Result == nil || response.Result.Transaction == nil || response.Result.Meta == nil { + t.Fatalf("fixture %s is missing transaction data", fixturePath) + } + + rawBinary := response.Result.Transaction.GetBinary() + if len(rawBinary) == 0 { + t.Fatalf("fixture %s has empty transaction bytes", fixturePath) + } + + txWithMeta := rpc.TransactionWithMeta{ + Slot: response.Result.Slot, + BlockTime: response.Result.BlockTime, + Transaction: rpc.DataBytesOrJSONFromBytes(rawBinary), + Meta: response.Result.Meta, + Version: response.Result.Version, + } + + var blockTime *uint64 + if response.Result.BlockTime != nil { + bt := uint64(*response.Result.BlockTime) + blockTime = &bt + } + + rawTx, err := FromRpcTransactionWithMeta(txWithMeta, blockTime, response.Result.Slot, 0) + if err != nil { + t.Fatalf("convert fixture %s: %v", fixturePath, err) + } + + tx, err := ParseRawTx(rawTx) + if err != nil { + t.Fatalf("parse fixture %s: %v", fixturePath, err) + } + return tx, len(rawBinary), len(raw) +} diff --git a/tx_binary_test.go b/tx_binary_test.go new file mode 100644 index 0000000..fb69dfd --- /dev/null +++ b/tx_binary_test.go @@ -0,0 +1,627 @@ +package pump_parser + +import ( + "bytes" + "io" + "testing" + + "github.com/gagliardetto/solana-go" + "github.com/shopspring/decimal" +) + +func TestTxBinaryRoundTrip(t *testing.T) { + txHash := [64]byte{} + for i := range txHash { + txHash[i] = byte(i + 1) + } + + original := &Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 123456789, + BlockIndex: 42, + TxHash: &txHash, + CuFee: decimal.NewFromInt(5000), + CUPrice: decimal.RequireFromString("0.123456"), + BeforeSolBalance: decimal.RequireFromString("1.500000000"), + AfterSOLBalance: decimal.RequireFromString("1.234567890"), + ComputeUnitsConsumed: 345678, + CuLimit: 400000, + Platform: map[string]platformInfo{ + PlatformGMGN: { + Platform: PlatformGMGN, + PlatformFee: decimal.RequireFromString("0.010000000"), + }, + PlatformPhoton: { + Platform: PlatformPhoton, + PlatformFee: decimal.RequireFromString("0.020000000"), + }, + }, + MevAgent: map[string]mevInfo{ + MevAgentJito: { + MevAgent: MevAgentJito, + MevAgentFee: decimal.RequireFromString("0.030000000"), + }, + }, + Swaps: []Swap{ + { + Program: SolProgramPump, + Event: TxEventBuy, + TxIndex: 7, + InstrIdx: 2, + InnerIdx: 1, + Pool: mustPubKey("11111111111111111111111111111111"), + BaseMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + QuoteMint: solana.WrappedSol, + BaseTokenProgram: solana.TokenProgramID, + QuoteTokenProgram: solana.TokenProgramID, + Creator: mustPubKey("BPFLoader1111111111111111111111111111111111"), + BaseMintDecimals: 6, + QuoteMintDecimals: 9, + User: mustPubKey("SysvarRent111111111111111111111111111111111"), + BaseAmount: decimal.NewFromInt(1200), + QuoteAmount: decimal.NewFromInt(3400), + SwapMode: SwapModeExactIn, + FixedAmount: decimal.NewFromInt(3400), + FixedAmountSide: SwapAmountSideQuote, + FixedMint: solana.WrappedSol, + LimitAmountType: SwapLimitTypeMinOut, + LimitAmount: decimal.NewFromInt(1000), + LimitAmountSide: SwapAmountSideBase, + LimitMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + ActualLimitAmount: decimal.NewFromInt(1200), + ActualLimitAmountSide: SwapAmountSideBase, + SlippageBps: decimal.RequireFromString("833.3333"), + BaseReserve: decimal.NewFromInt(5555), + QuoteReserve: decimal.NewFromInt(9999), + Mayhem: true, + Cashback: false, + UserBaseBalance: decimal.NewFromInt(777), + UserQuoteBalance: decimal.NewFromInt(888), + EntryContract: mustPubKey("ComputeBudget111111111111111111111111111111"), + MigrateToPool: mustPubKey("MemoSq4gqABAXKb96qnH8TysNcWxMyWCqXgDLGmfcHr"), + MigrateTopProgram: mustPubKey("AddressLookupTab1e1111111111111111111111111"), + LpMint: mustPubKey("4Nd1mJf8JQhRVTfJxW2YxXLNQKhPYo1JzN1u2KAPY1Hn"), + AfterSOLBalance: decimal.RequireFromString("0.321000000"), + ActiveBinId: 11, + FeeAmount: decimal.NewFromInt(99), + FeeBps: "123", + FeeSide: "base", + ConsumeUnit: 9999, + }, + }, + } + + encoded, err := EncodeTxBinary(original) + if err != nil { + t.Fatalf("EncodeTxBinary() error = %v", err) + } + + decoded, err := DecodeTxBinary(encoded) + if err != nil { + t.Fatalf("DecodeTxBinary() error = %v", err) + } + + if decoded.Signer != original.Signer { + t.Fatalf("Signer = %s, want %s", decoded.Signer, original.Signer) + } + if decoded.Block != original.Block { + t.Fatalf("Block = %d, want %d", decoded.Block, original.Block) + } + if decoded.BlockIndex != original.BlockIndex { + t.Fatalf("BlockIndex = %d, want %d", decoded.BlockIndex, original.BlockIndex) + } + if decoded.TxHash == nil { + t.Fatal("TxHash = nil, want non-nil") + } + if *decoded.TxHash != *original.TxHash { + t.Fatalf("TxHash mismatch") + } + if !decoded.CuFee.Equal(original.CuFee) { + t.Fatalf("CuFee = %s, want %s", decoded.CuFee, original.CuFee) + } + if !decoded.CUPrice.Equal(original.CUPrice) { + t.Fatalf("CUPrice = %s, want %s", decoded.CUPrice, original.CUPrice) + } + if decoded.BeforeSolBalance.StringFixed(9) != original.BeforeSolBalance.StringFixed(9) { + t.Fatalf("BeforeSolBalance = %s, want %s", decoded.BeforeSolBalance, original.BeforeSolBalance) + } + if decoded.AfterSOLBalance.StringFixed(9) != original.AfterSOLBalance.StringFixed(9) { + t.Fatalf("AfterSOLBalance = %s, want %s", decoded.AfterSOLBalance, original.AfterSOLBalance) + } + if decoded.CuLimit != original.CuLimit { + t.Fatalf("CuLimit = %d, want %d", decoded.CuLimit, original.CuLimit) + } + if decoded.ComputeUnitsConsumed != original.ComputeUnitsConsumed { + t.Fatalf("ComputeUnitsConsumed = %d, want %d", decoded.ComputeUnitsConsumed, original.ComputeUnitsConsumed) + } + if len(decoded.Platform) != len(original.Platform) { + t.Fatalf("Platform len = %d, want %d", len(decoded.Platform), len(original.Platform)) + } + if !decoded.Platform[PlatformGMGN].PlatformFee.Equal(original.Platform[PlatformGMGN].PlatformFee) { + t.Fatalf("Platform fee mismatch") + } + if len(decoded.MevAgent) != len(original.MevAgent) { + t.Fatalf("MevAgent len = %d, want %d", len(decoded.MevAgent), len(original.MevAgent)) + } + if !decoded.MevAgent[MevAgentJito].MevAgentFee.Equal(original.MevAgent[MevAgentJito].MevAgentFee) { + t.Fatalf("MevAgent fee mismatch") + } + if len(decoded.Swaps) != 1 { + t.Fatalf("Swaps len = %d, want 1", len(decoded.Swaps)) + } + + swap := decoded.Swaps[0] + if swap.Program != original.Swaps[0].Program { + t.Fatalf("swap.Program = %s, want %s", swap.Program, original.Swaps[0].Program) + } + if swap.Event != original.Swaps[0].Event { + t.Fatalf("swap.Event = %s, want %s", swap.Event, original.Swaps[0].Event) + } + if swap.TxIndex != original.Swaps[0].TxIndex { + t.Fatalf("swap.TxIndex = %d, want %d", swap.TxIndex, original.Swaps[0].TxIndex) + } + if !swap.BaseAmount.Equal(original.Swaps[0].BaseAmount) { + t.Fatalf("swap.BaseAmount = %s, want %s", swap.BaseAmount, original.Swaps[0].BaseAmount) + } + if !swap.QuoteAmount.Equal(original.Swaps[0].QuoteAmount) { + t.Fatalf("swap.QuoteAmount = %s, want %s", swap.QuoteAmount, original.Swaps[0].QuoteAmount) + } + if !swap.FixedAmount.Equal(original.Swaps[0].FixedAmount) { + t.Fatalf("swap.FixedAmount = %s, want %s", swap.FixedAmount, original.Swaps[0].FixedAmount) + } + if !swap.LimitAmount.Equal(original.Swaps[0].LimitAmount) { + t.Fatalf("swap.LimitAmount = %s, want %s", swap.LimitAmount, original.Swaps[0].LimitAmount) + } + if !swap.ActualLimitAmount.Equal(original.Swaps[0].ActualLimitAmount) { + t.Fatalf("swap.ActualLimitAmount = %s, want %s", swap.ActualLimitAmount, original.Swaps[0].ActualLimitAmount) + } + if swap.SlippageBps.String() != "833" { + t.Fatalf("swap.SlippageBps = %s, want 833", swap.SlippageBps) + } + if !swap.BaseReserve.Equal(original.Swaps[0].BaseReserve) { + t.Fatalf("swap.BaseReserve = %s, want %s", swap.BaseReserve, original.Swaps[0].BaseReserve) + } + if !swap.QuoteReserve.Equal(original.Swaps[0].QuoteReserve) { + t.Fatalf("swap.QuoteReserve = %s, want %s", swap.QuoteReserve, original.Swaps[0].QuoteReserve) + } + if !swap.UserBaseBalance.Equal(original.Swaps[0].UserBaseBalance) { + t.Fatalf("swap.UserBaseBalance = %s, want %s", swap.UserBaseBalance, original.Swaps[0].UserBaseBalance) + } + if !swap.UserQuoteBalance.Equal(original.Swaps[0].UserQuoteBalance) { + t.Fatalf("swap.UserQuoteBalance = %s, want %s", swap.UserQuoteBalance, original.Swaps[0].UserQuoteBalance) + } + if swap.AfterSOLBalance.StringFixed(9) != original.Swaps[0].AfterSOLBalance.StringFixed(9) { + t.Fatalf("swap.AfterSOLBalance = %s, want %s", swap.AfterSOLBalance, original.Swaps[0].AfterSOLBalance) + } + + if swap.ActiveBinId != 0 { + t.Fatalf("swap.ActiveBinId = %d, want 0", swap.ActiveBinId) + } + if !swap.FeeAmount.IsZero() { + t.Fatalf("swap.FeeAmount = %s, want 0", swap.FeeAmount) + } + if swap.FeeBps != "" { + t.Fatalf("swap.FeeBps = %q, want empty", swap.FeeBps) + } + if swap.FeeSide != "" { + t.Fatalf("swap.FeeSide = %q, want empty", swap.FeeSide) + } + if swap.ConsumeUnit != 0 { + t.Fatalf("swap.ConsumeUnit = %d, want 0", swap.ConsumeUnit) + } +} + +func TestTxBinaryRejectsUnknownProgramEnum(t *testing.T) { + txBinary := &TxBinary{ + SchemaVersion: txBinarySchemaVersionCurrent, + EnumVersion: txBinaryEnumVersionV1, + Swaps: []SwapBinary{ + {Program: "unknown_program"}, + }, + } + + if _, err := txBinary.MarshalBinary(); err == nil { + t.Fatal("MarshalBinary() error = nil, want error") + } +} + +func TestTxsBinaryRoundTripWithSharedAddressTable(t *testing.T) { + tx1 := Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 1, + BlockIndex: 1, + CuFee: decimal.NewFromInt(1000), + CUPrice: decimal.RequireFromString("0.123456"), + BeforeSolBalance: decimal.RequireFromString("1.000000000"), + AfterSOLBalance: decimal.RequireFromString("0.900000000"), + ComputeUnitsConsumed: 100, + CuLimit: 200000, + Swaps: []Swap{ + { + Program: SolProgramPump, + Event: TxEventBuy, + TxIndex: 1, + InstrIdx: 0, + InnerIdx: 0, + Pool: mustPubKey("11111111111111111111111111111111"), + BaseMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + QuoteMint: solana.WrappedSol, + BaseTokenProgram: solana.TokenProgramID, + QuoteTokenProgram: solana.TokenProgramID, + Creator: mustPubKey("BPFLoader1111111111111111111111111111111111"), + BaseMintDecimals: 6, + QuoteMintDecimals: 9, + User: mustPubKey("SysvarRent111111111111111111111111111111111"), + BaseAmount: decimal.NewFromInt(10), + QuoteAmount: decimal.NewFromInt(20), + SwapMode: SwapModeExactIn, + FixedAmount: decimal.NewFromInt(20), + FixedAmountSide: SwapAmountSideQuote, + FixedMint: solana.WrappedSol, + LimitAmountType: SwapLimitTypeMinOut, + LimitAmount: decimal.NewFromInt(9), + LimitAmountSide: SwapAmountSideBase, + LimitMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + ActualLimitAmount: decimal.NewFromInt(10), + ActualLimitAmountSide: SwapAmountSideBase, + SlippageBps: decimal.RequireFromString("100.2"), + BaseReserve: decimal.NewFromInt(100), + QuoteReserve: decimal.NewFromInt(200), + UserBaseBalance: decimal.NewFromInt(1), + UserQuoteBalance: decimal.NewFromInt(2), + EntryContract: solana.PublicKey{}, + MigrateToPool: solana.PublicKey{}, + MigrateTopProgram: solana.PublicKey{}, + LpMint: solana.PublicKey{}, + AfterSOLBalance: decimal.RequireFromString("0.800000000"), + }, + }, + } + tx2 := tx1 + tx2.Block = 2 + tx2.BlockIndex = 2 + tx2.CuFee = decimal.NewFromInt(2000) + tx2.AfterSOLBalance = decimal.RequireFromString("0.700000000") + tx2.Swaps = []Swap{tx1.Swaps[0]} + tx2.Swaps[0].TxIndex = 2 + tx2.Swaps[0].BaseAmount = decimal.NewFromInt(30) + tx2.Swaps[0].QuoteAmount = decimal.NewFromInt(40) + + batchEncoded, err := EncodeTxsBinary([]Tx{tx1, tx2}) + if err != nil { + t.Fatalf("EncodeTxsBinary() error = %v", err) + } + decoded, err := DecodeTxsBinary(batchEncoded) + if err != nil { + t.Fatalf("DecodeTxsBinary() error = %v", err) + } + if len(decoded) != 2 { + t.Fatalf("decoded len = %d, want 2", len(decoded)) + } + if decoded[0].Signer != tx1.Signer || decoded[1].Signer != tx2.Signer { + t.Fatalf("decoded signer mismatch") + } + if decoded[0].Swaps[0].Pool != tx1.Swaps[0].Pool || decoded[1].Swaps[0].Pool != tx2.Swaps[0].Pool { + t.Fatalf("decoded shared address mismatch") + } + + single1, err := EncodeTxBinary(&tx1) + if err != nil { + t.Fatalf("EncodeTxBinary(tx1) error = %v", err) + } + single2, err := EncodeTxBinary(&tx2) + if err != nil { + t.Fatalf("EncodeTxBinary(tx2) error = %v", err) + } + if len(batchEncoded) >= len(single1)+len(single2) { + t.Fatalf("batch encoded = %d, want smaller than singles sum %d", len(batchEncoded), len(single1)+len(single2)) + } +} + +func TestDecodeTxsBinaryReader(t *testing.T) { + tx1 := Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 100, + BlockIndex: 7, + CuFee: decimal.NewFromInt(111), + CUPrice: decimal.RequireFromString("0.123456"), + BeforeSolBalance: decimal.RequireFromString("1.000000000"), + AfterSOLBalance: decimal.RequireFromString("0.500000000"), + ComputeUnitsConsumed: 1234, + CuLimit: 250000, + Swaps: []Swap{ + { + Program: SolProgramPump, + Event: TxEventBuy, + TxIndex: 3, + InstrIdx: 1, + InnerIdx: 2, + Pool: mustPubKey("11111111111111111111111111111111"), + BaseMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + QuoteMint: solana.WrappedSol, + BaseTokenProgram: solana.TokenProgramID, + QuoteTokenProgram: solana.TokenProgramID, + Creator: mustPubKey("BPFLoader1111111111111111111111111111111111"), + BaseMintDecimals: 6, + QuoteMintDecimals: 9, + User: mustPubKey("SysvarRent111111111111111111111111111111111"), + BaseAmount: decimal.NewFromInt(100), + QuoteAmount: decimal.NewFromInt(200), + SwapMode: SwapModeExactIn, + FixedAmount: decimal.NewFromInt(200), + FixedAmountSide: SwapAmountSideQuote, + FixedMint: solana.WrappedSol, + LimitAmountType: SwapLimitTypeMinOut, + LimitAmount: decimal.NewFromInt(90), + LimitAmountSide: SwapAmountSideBase, + LimitMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + ActualLimitAmount: decimal.NewFromInt(100), + ActualLimitAmountSide: SwapAmountSideBase, + SlippageBps: decimal.RequireFromString("99.6"), + BaseReserve: decimal.NewFromInt(1000), + QuoteReserve: decimal.NewFromInt(2000), + UserBaseBalance: decimal.NewFromInt(10), + UserQuoteBalance: decimal.NewFromInt(20), + AfterSOLBalance: decimal.RequireFromString("0.400000000"), + }, + }, + } + tx2 := tx1 + tx2.Block = 101 + tx2.BlockIndex = 8 + tx2.CuFee = decimal.NewFromInt(222) + tx2.AfterSOLBalance = decimal.RequireFromString("0.300000000") + tx2.Swaps = []Swap{tx1.Swaps[0]} + tx2.Swaps[0].TxIndex = 4 + tx2.Swaps[0].BaseAmount = decimal.NewFromInt(300) + + encoded, err := EncodeTxsBinary([]Tx{tx1, tx2}) + if err != nil { + t.Fatalf("EncodeTxsBinary() error = %v", err) + } + + var decoded []*Tx + for tx, err := range DecodeTxsBinaryReader(bytes.NewReader(encoded)) { + if err != nil { + t.Fatalf("DecodeTxsBinaryReader() error = %v", err) + } + decoded = append(decoded, tx) + } + + if len(decoded) != 2 { + t.Fatalf("decoded len = %d, want 2", len(decoded)) + } + if decoded[0].Signer != tx1.Signer || decoded[1].Signer != tx2.Signer { + t.Fatalf("decoded signer mismatch") + } + if decoded[0].Block != tx1.Block || decoded[1].Block != tx2.Block { + t.Fatalf("decoded block mismatch") + } + if decoded[0].Swaps[0].BaseAmount.Cmp(tx1.Swaps[0].BaseAmount) != 0 { + t.Fatalf("decoded tx1 swap base amount = %s, want %s", decoded[0].Swaps[0].BaseAmount, tx1.Swaps[0].BaseAmount) + } + if decoded[1].Swaps[0].BaseAmount.Cmp(tx2.Swaps[0].BaseAmount) != 0 { + t.Fatalf("decoded tx2 swap base amount = %s, want %s", decoded[1].Swaps[0].BaseAmount, tx2.Swaps[0].BaseAmount) + } +} + +func TestDecodeTxsBinaryReaderEarlyStop(t *testing.T) { + tx := Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 1, + BlockIndex: 1, + CuFee: decimal.NewFromInt(1), + CUPrice: decimal.RequireFromString("0.000001"), + BeforeSolBalance: decimal.RequireFromString("1.000000000"), + AfterSOLBalance: decimal.RequireFromString("0.999999999"), + ComputeUnitsConsumed: 1, + CuLimit: 1, + } + encoded, err := EncodeTxsBinary([]Tx{tx, tx, tx}) + if err != nil { + t.Fatalf("EncodeTxsBinary() error = %v", err) + } + + count := 0 + for decodedTx, err := range DecodeTxsBinaryReader(bytes.NewReader(encoded)) { + if err != nil { + t.Fatalf("DecodeTxsBinaryReader() error = %v", err) + } + if decodedTx == nil { + t.Fatal("decoded tx is nil") + } + count++ + break + } + + if count != 1 { + t.Fatalf("count = %d, want 1", count) + } +} + +func TestMergeTxsBinaryBytes(t *testing.T) { + tx1 := Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 11, + BlockIndex: 1, + CuFee: decimal.NewFromInt(10), + CUPrice: decimal.RequireFromString("0.000123"), + BeforeSolBalance: decimal.RequireFromString("1.100000000"), + AfterSOLBalance: decimal.RequireFromString("1.000000000"), + ComputeUnitsConsumed: 10, + CuLimit: 100, + Swaps: []Swap{ + { + Program: SolProgramPump, + Event: TxEventBuy, + TxIndex: 1, + Pool: mustPubKey("11111111111111111111111111111111"), + BaseMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + QuoteMint: solana.WrappedSol, + BaseTokenProgram: solana.TokenProgramID, + QuoteTokenProgram: solana.TokenProgramID, + Creator: mustPubKey("BPFLoader1111111111111111111111111111111111"), + User: mustPubKey("SysvarRent111111111111111111111111111111111"), + FixedMint: solana.WrappedSol, + LimitMint: mustPubKey("3wyAj7RtG72wM1Wv9DkYfL7RAx9X3Jx1sC6E6mN4jWeL"), + EntryContract: solana.PublicKey{}, + MigrateToPool: solana.PublicKey{}, + MigrateTopProgram: solana.PublicKey{}, + LpMint: solana.PublicKey{}, + }, + }, + } + tx2 := Tx{ + Signer: mustPubKey("SysvarRent111111111111111111111111111111111"), + Block: 12, + BlockIndex: 2, + CuFee: decimal.NewFromInt(20), + CUPrice: decimal.RequireFromString("0.000456"), + BeforeSolBalance: decimal.RequireFromString("2.200000000"), + AfterSOLBalance: decimal.RequireFromString("2.000000000"), + ComputeUnitsConsumed: 20, + CuLimit: 200, + Swaps: []Swap{ + { + Program: SolProgramPump, + Event: TxEventSell, + TxIndex: 2, + Pool: mustPubKey("MemoSq4gqABAXKb96qnH8TysNcWxMyWCqXgDLGmfcHr"), + BaseMint: mustPubKey("4Nd1mJf8JQhRVTfJxW2YxXLNQKhPYo1JzN1u2KAPY1Hn"), + QuoteMint: solana.WrappedSol, + BaseTokenProgram: solana.TokenProgramID, + QuoteTokenProgram: solana.TokenProgramID, + Creator: mustPubKey("ComputeBudget111111111111111111111111111111"), + User: mustPubKey("So11111111111111111111111111111111111111112"), + FixedMint: solana.WrappedSol, + LimitMint: mustPubKey("4Nd1mJf8JQhRVTfJxW2YxXLNQKhPYo1JzN1u2KAPY1Hn"), + EntryContract: solana.PublicKey{}, + MigrateToPool: solana.PublicKey{}, + MigrateTopProgram: solana.PublicKey{}, + LpMint: solana.PublicKey{}, + }, + }, + } + + batch1, err := EncodeTxsBinary([]Tx{tx1}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch1) error = %v", err) + } + batch2, err := EncodeTxsBinary([]Tx{tx2}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch2) error = %v", err) + } + + merged, err := MergeTxsBinaryBytes([][]byte{batch1, batch2}) + if err != nil { + t.Fatalf("MergeTxsBinaryBytes() error = %v", err) + } + + var mergedBinary TxsBinary + if err := mergedBinary.UnmarshalBinary(merged); err != nil { + t.Fatalf("UnmarshalBinary(merged) error = %v", err) + } + if len(mergedBinary.Txs) != 2 { + t.Fatalf("merged tx count = %d, want 2", len(mergedBinary.Txs)) + } + if len(mergedBinary.AddressTable) >= len(mustTxBinary(t, batch1).AddressTable)+len(mustTxBinary(t, batch2).AddressTable) { + t.Fatalf("merged address table was not deduplicated") + } + + decoded, err := DecodeTxsBinary(merged) + if err != nil { + t.Fatalf("DecodeTxsBinary(merged) error = %v", err) + } + if len(decoded) != 2 { + t.Fatalf("decoded len = %d, want 2", len(decoded)) + } + if decoded[0].Block != tx1.Block || decoded[1].Block != tx2.Block { + t.Fatalf("decoded block mismatch") + } +} + +func TestMergeTxsBinarySourcesToWriterWithConcatenatedBatches(t *testing.T) { + tx1 := Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 21, + BlockIndex: 1, + CuFee: decimal.NewFromInt(1), + CUPrice: decimal.RequireFromString("0.000001"), + BeforeSolBalance: decimal.RequireFromString("1.000000000"), + AfterSOLBalance: decimal.RequireFromString("0.900000000"), + ComputeUnitsConsumed: 11, + CuLimit: 111, + } + tx2 := tx1 + tx2.Block = 22 + tx2.BlockIndex = 2 + tx2.Signer = mustPubKey("SysvarRent111111111111111111111111111111111") + tx3 := tx1 + tx3.Block = 23 + tx3.BlockIndex = 3 + tx3.Signer = mustPubKey("ComputeBudget111111111111111111111111111111") + + batch1, err := EncodeTxsBinary([]Tx{tx1}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch1) error = %v", err) + } + batch2, err := EncodeTxsBinary([]Tx{tx2}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch2) error = %v", err) + } + batch3, err := EncodeTxsBinary([]Tx{tx3}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch3) error = %v", err) + } + + source1 := &testTxsBinarySource{ + data: append(append([]byte{}, batch1...), batch2...), + } + source2 := &testTxsBinarySource{ + data: batch3, + } + + var out bytes.Buffer + if err := MergeTxsBinarySourcesToWriter([]TxsBinaryReaderSource{source1, source2}, &out); err != nil { + t.Fatalf("MergeTxsBinarySourcesToWriter() error = %v", err) + } + + if source1.opens != 2 || source2.opens != 2 { + t.Fatalf("source opens = (%d, %d), want (2, 2)", source1.opens, source2.opens) + } + + decoded, err := DecodeTxsBinary(out.Bytes()) + if err != nil { + t.Fatalf("DecodeTxsBinary(merged) error = %v", err) + } + if len(decoded) != 3 { + t.Fatalf("decoded len = %d, want 3", len(decoded)) + } + if decoded[0].Block != tx1.Block || decoded[1].Block != tx2.Block || decoded[2].Block != tx3.Block { + t.Fatalf("decoded block order mismatch") + } +} + +func mustPubKey(value string) solana.PublicKey { + return solana.MustPublicKeyFromBase58(value) +} + +func mustTxBinary(t *testing.T, data []byte) *TxsBinary { + t.Helper() + + var txsBinary TxsBinary + if err := txsBinary.UnmarshalBinary(data); err != nil { + t.Fatalf("UnmarshalBinary() error = %v", err) + } + return &txsBinary +} + +type testTxsBinarySource struct { + data []byte + opens int +} + +func (s *testTxsBinarySource) OpenTxsBinaryReader() (io.ReadCloser, error) { + s.opens++ + return io.NopCloser(bytes.NewReader(s.data)), nil +}