diff --git a/tx_binary.go b/tx_binary.go index c5b8c61..3620967 100644 --- a/tx_binary.go +++ b/tx_binary.go @@ -1,6 +1,7 @@ package pump_parser import ( + "bufio" "bytes" "encoding/binary" "fmt" @@ -107,6 +108,18 @@ type TxsBinaryReaderSource interface { OpenTxsBinaryReader() (io.ReadCloser, error) } +type TxsBinaryBatchHeaderContext struct { + SourceIndex int + BatchIndex int + Reader *bufio.Reader +} + +type TxsBinaryBatchHeaderFunc func(ctx *TxsBinaryBatchHeaderContext) (skip bool, err error) + +type TxsBinaryMergeOptions struct { + BatchHeaderFunc TxsBinaryBatchHeaderFunc +} + type PlatformBinary struct { Platform string PlatformFee uint64 @@ -307,24 +320,32 @@ func DecodeTxsBinaryReader(r io.Reader) iter.Seq2[*Tx, error] { } func MergeTxsBinaryBytes(encodedBatches [][]byte) ([]byte, error) { + return MergeTxsBinaryBytesWithOptions(encodedBatches, TxsBinaryMergeOptions{}) +} + +func MergeTxsBinaryBytesWithOptions(encodedBatches [][]byte, opts TxsBinaryMergeOptions) ([]byte, error) { sources := make([]TxsBinaryReaderSource, 0, len(encodedBatches)) for _, encoded := range encodedBatches { sources = append(sources, txBinaryBytesSource{data: encoded}) } var out bytes.Buffer - if err := MergeTxsBinarySourcesToWriter(sources, &out); err != nil { + if err := MergeTxsBinarySourcesToWriterWithOptions(sources, &out, opts); err != nil { return nil, err } return out.Bytes(), nil } func MergeTxsBinarySourcesToWriter(sources []TxsBinaryReaderSource, w io.Writer) error { + return MergeTxsBinarySourcesToWriterWithOptions(sources, w, TxsBinaryMergeOptions{}) +} + +func MergeTxsBinarySourcesToWriterWithOptions(sources []TxsBinaryReaderSource, w io.Writer, opts TxsBinaryMergeOptions) error { if w == nil { return fmt.Errorf("txs binary writer is nil") } - plan, err := txBinaryBuildMergePlan(sources) + plan, err := txBinaryBuildMergePlan(sources, opts) if err != nil { return err } @@ -343,9 +364,22 @@ func MergeTxsBinarySourcesToWriter(sources []TxsBinaryReaderSource, w io.Writer) return fmt.Errorf("source[%d]: open reader: %w", sourceIndex, err) } - dec := txBinaryStreamDecoder{reader: reader} + bufferedReader := bufio.NewReader(reader) + dec := txBinaryStreamDecoder{reader: bufferedReader} batchIndex := 0 for { + skipBatch, err := txBinaryApplyMergeBatchHeader(bufferedReader, opts, sourceIndex, batchIndex) + if err != nil { + closeErr := reader.Close() + if err == io.EOF { + if closeErr != nil { + return fmt.Errorf("source[%d]: close reader: %w", sourceIndex, closeErr) + } + break + } + return fmt.Errorf("source[%d].batch[%d]: %w", sourceIndex, batchIndex, err) + } + header, err := dec.readTxsBinaryHeaderOrEOF() if err != nil { closeErr := reader.Close() @@ -368,6 +402,9 @@ func MergeTxsBinarySourcesToWriter(sources []TxsBinaryReaderSource, w io.Writer) reader.Close() return fmt.Errorf("source[%d].batch[%d].tx[%d]: %w", sourceIndex, batchIndex, txIndex, err) } + if skipBatch { + continue + } if err := txBinaryRemapTxAddressTable(&tx, header.addressTable, plan.addressTable, plan.addressIndex); err != nil { reader.Close() return fmt.Errorf("source[%d].batch[%d].tx[%d]: %w", sourceIndex, batchIndex, txIndex, err) @@ -1780,7 +1817,7 @@ func txBinaryReadTxBody(dec txBinaryBodyReader, tx *TxBinary, enumTable *txBinar return nil } -func txBinaryBuildMergePlan(sources []TxsBinaryReaderSource) (*txsBinaryMergePlan, error) { +func txBinaryBuildMergePlan(sources []TxsBinaryReaderSource, opts TxsBinaryMergeOptions) (*txsBinaryMergePlan, error) { if len(sources) == 0 { return nil, fmt.Errorf("txs binary sources are empty") } @@ -1801,9 +1838,22 @@ func txBinaryBuildMergePlan(sources []TxsBinaryReaderSource) (*txsBinaryMergePla return nil, fmt.Errorf("source[%d]: open reader: %w", sourceIndex, err) } - dec := txBinaryStreamDecoder{reader: reader} + bufferedReader := bufio.NewReader(reader) + dec := txBinaryStreamDecoder{reader: bufferedReader} batchIndex := 0 for { + skipBatch, err := txBinaryApplyMergeBatchHeader(bufferedReader, opts, sourceIndex, batchIndex) + if err != nil { + closeErr := reader.Close() + if err == io.EOF { + if closeErr != nil { + return nil, fmt.Errorf("source[%d]: close reader: %w", sourceIndex, closeErr) + } + break + } + return nil, fmt.Errorf("source[%d].batch[%d]: %w", sourceIndex, batchIndex, err) + } + header, err := dec.readTxsBinaryHeaderOrEOF() if err != nil { closeErr := reader.Close() @@ -1833,17 +1883,21 @@ func txBinaryBuildMergePlan(sources []TxsBinaryReaderSource) (*txsBinaryMergePla } for addressIndex, address := range header.addressTable { - if err := builder.add(address); err != nil { - reader.Close() - return nil, fmt.Errorf("source[%d].batch[%d].address[%d]: %w", sourceIndex, batchIndex, addressIndex, err) + if !skipBatch { + if err := builder.add(address); err != nil { + reader.Close() + return nil, fmt.Errorf("source[%d].batch[%d].address[%d]: %w", sourceIndex, batchIndex, addressIndex, err) + } } } - if uint64(plan.txCount)+uint64(header.count) > uint64(math.MaxUint32) { - reader.Close() - return nil, fmt.Errorf("merged tx count exceeds uint32 capacity") + if !skipBatch { + if uint64(plan.txCount)+uint64(header.count) > uint64(math.MaxUint32) { + reader.Close() + return nil, fmt.Errorf("merged tx count exceeds uint32 capacity") + } + plan.txCount += header.count } - plan.txCount += header.count for txIndex := uint32(0); txIndex < header.count; txIndex++ { tx := TxBinary{ @@ -1947,6 +2001,17 @@ func txBinaryWriteAll(w io.Writer, data []byte) error { return nil } +func txBinaryApplyMergeBatchHeader(reader *bufio.Reader, opts TxsBinaryMergeOptions, sourceIndex int, batchIndex int) (bool, error) { + if opts.BatchHeaderFunc == nil { + return false, nil + } + return opts.BatchHeaderFunc(&TxsBinaryBatchHeaderContext{ + SourceIndex: sourceIndex, + BatchIndex: batchIndex, + Reader: reader, + }) +} + type txBinaryEnumTable struct { version uint16 programs txBinaryEnumSet diff --git a/tx_binary_test.go b/tx_binary_test.go index fb69dfd..1fb71e9 100644 --- a/tx_binary_test.go +++ b/tx_binary_test.go @@ -602,6 +602,89 @@ func TestMergeTxsBinarySourcesToWriterWithConcatenatedBatches(t *testing.T) { } } +func TestMergeTxsBinarySourcesToWriterWithBatchHeaderFuncSkip(t *testing.T) { + tx1 := Tx{ + Signer: mustPubKey("So11111111111111111111111111111111111111112"), + Block: 31, + BlockIndex: 1, + CuFee: decimal.NewFromInt(1), + CUPrice: decimal.RequireFromString("0.000001"), + BeforeSolBalance: decimal.RequireFromString("1.000000000"), + AfterSOLBalance: decimal.RequireFromString("0.900000000"), + ComputeUnitsConsumed: 11, + CuLimit: 111, + } + tx2 := tx1 + tx2.Block = 32 + tx2.BlockIndex = 2 + tx2.Signer = mustPubKey("SysvarRent111111111111111111111111111111111") + tx3 := tx1 + tx3.Block = 33 + tx3.BlockIndex = 3 + tx3.Signer = mustPubKey("ComputeBudget111111111111111111111111111111") + + batch1, err := EncodeTxsBinary([]Tx{tx1}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch1) error = %v", err) + } + batch2, err := EncodeTxsBinary([]Tx{tx2}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch2) error = %v", err) + } + batch3, err := EncodeTxsBinary([]Tx{tx3}) + if err != nil { + t.Fatalf("EncodeTxsBinary(batch3) error = %v", err) + } + + source := &testTxsBinarySource{ + data: append( + append( + append([]byte{}, testBatchHeader(false)...), + batch1..., + ), + append( + append(testBatchHeader(true), batch2...), + append(testBatchHeader(false), batch3...)..., + )..., + ), + } + + var out bytes.Buffer + err = MergeTxsBinarySourcesToWriterWithOptions( + []TxsBinaryReaderSource{source}, + &out, + TxsBinaryMergeOptions{ + BatchHeaderFunc: func(ctx *TxsBinaryBatchHeaderContext) (bool, error) { + header := make([]byte, 5) + if _, err := io.ReadFull(ctx.Reader, header); err != nil { + return false, err + } + if !bytes.Equal(header[:4], []byte("BHDR")) { + return false, io.ErrUnexpectedEOF + } + return header[4] == 1, nil + }, + }, + ) + if err != nil { + t.Fatalf("MergeTxsBinarySourcesToWriterWithOptions() error = %v", err) + } + + decoded, err := DecodeTxsBinary(out.Bytes()) + if err != nil { + t.Fatalf("DecodeTxsBinary(merged) error = %v", err) + } + if len(decoded) != 2 { + t.Fatalf("decoded len = %d, want 2", len(decoded)) + } + if decoded[0].Block != tx1.Block || decoded[1].Block != tx3.Block { + t.Fatalf("decoded block order mismatch after skip") + } + if source.opens != 2 { + t.Fatalf("source.opens = %d, want 2", source.opens) + } +} + func mustPubKey(value string) solana.PublicKey { return solana.MustPublicKeyFromBase58(value) } @@ -625,3 +708,11 @@ func (s *testTxsBinarySource) OpenTxsBinaryReader() (io.ReadCloser, error) { s.opens++ return io.NopCloser(bytes.NewReader(s.data)), nil } + +func testBatchHeader(skip bool) []byte { + header := []byte("BHDR\x00") + if skip { + header[4] = 1 + } + return header +}