diff --git a/.gitignore b/.gitignore index ba8d0fb..38c7121 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,4 @@ data/ # Build artifacts *.tar.gz *.zip +.amp_state \ No newline at end of file diff --git a/.test.env b/.test.env new file mode 100644 index 0000000..e69de29 diff --git a/Dockerfile.kafka b/Dockerfile.kafka new file mode 100644 index 0000000..2b902ac --- /dev/null +++ b/Dockerfile.kafka @@ -0,0 +1,23 @@ +FROM python:3.12-slim + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv + +WORKDIR /app + +COPY pyproject.toml README.md ./ +COPY src/ ./src/ +COPY apps/ ./apps/ + +RUN uv pip install --system --no-cache . && \ + uv pip install --system --no-cache kafka-python lmdb + +ENV PYTHONPATH=/app +ENV PYTHONUNBUFFERED=1 + +ENTRYPOINT ["python", "apps/kafka_streaming_loader.py"] +CMD ["--help"] diff --git a/apps/kafka_consumer.py b/apps/kafka_consumer.py new file mode 100755 index 0000000..5fabdd1 --- /dev/null +++ b/apps/kafka_consumer.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +"""Simple Kafka consumer script to print messages from a topic in real-time. + +Messages are consumed from a consumer group, so subsequent runs will only show +new messages. Press Ctrl+C to exit cleanly. + +Usage: + python kafka_consumer.py [topic] [broker] [group_id] + +Examples: + python kafka_consumer.py + python kafka_consumer.py anvil_logs + python kafka_consumer.py anvil_logs localhost:9092 + python kafka_consumer.py anvil_logs localhost:9092 my-group +""" + +import json +import sys +from datetime import datetime + +from kafka import KafkaConsumer + +topic = sys.argv[1] if len(sys.argv) > 1 else 'anvil_logs' +broker = sys.argv[2] if len(sys.argv) > 2 else 'localhost:9092' +group_id = sys.argv[3] if len(sys.argv) > 3 else 'kafka-consumer-cli' + +print(f'Consuming from: {broker} -> topic: {topic}') +print(f'Consumer group: {group_id}') +print(f'Started at: {datetime.now().strftime("%H:%M:%S")}') +print('-' * 80) + +consumer = KafkaConsumer( + topic, + bootstrap_servers=broker, + group_id=group_id, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), +) + +msg_count = 0 +data_count = 0 +reorg_count = 0 + +try: + for msg in consumer: + msg_count += 1 + msg_type = msg.value.get('_type', 'unknown') + + if msg_type == 'data': + data_count += 1 + print(f'\nMessage #{msg_count} [DATA] - Key: {msg.key.decode() if msg.key else "None"}') + print(f'Offset: {msg.offset} | Partition: {msg.partition}') + + for k, v in msg.value.items(): + if k != '_type': + print(f'{k}: {v}') + + elif msg_type == 'reorg': + reorg_count += 1 + print(f'\nMessage #{msg_count} [REORG] - Key: {msg.key.decode() if msg.key else "None"}') + print(f'Network: {msg.value.get("network")}') + print(f'Blocks: {msg.value.get("start_block")} -> {msg.value.get("end_block")}') + + else: + print(f'\nMessage #{msg_count} [UNKNOWN]') + print(json.dumps(msg.value, indent=2)) + + print(f'\nTotal: {msg_count} msgs | Data: {data_count} | Reorgs: {reorg_count}') + print('-' * 80) + +except KeyboardInterrupt: + print('\n\nStopped') +finally: + consumer.close() diff --git a/apps/kafka_streaming_loader.py b/apps/kafka_streaming_loader.py new file mode 100644 index 0000000..439a616 --- /dev/null +++ b/apps/kafka_streaming_loader.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +"""Stream data to Kafka with resume watermark support.""" + +import argparse 
+import json +import logging +import os +import time +from pathlib import Path + +from amp.client import Client +from amp.loaders.types import LabelJoinConfig +from amp.streaming import BlockRange, ResumeWatermark + +logger = logging.getLogger('amp.kafka_streaming_loader') + +RETRYABLE_ERRORS = ( + ConnectionError, + TimeoutError, + OSError, +) + + +def retry_with_backoff(func, max_retries=5, initial_delay=1.0, max_delay=60.0, backoff_factor=2.0): + """Execute function with exponential backoff retry on transient errors.""" + delay = initial_delay + last_exception = None + + for attempt in range(max_retries + 1): + try: + return func() + except RETRYABLE_ERRORS as e: + last_exception = e + if attempt == max_retries: + logger.error(f'Max retries ({max_retries}) exceeded: {e}') + raise + logger.warning(f'Attempt {attempt + 1} failed: {e}. Retrying in {delay:.1f}s...') + time.sleep(delay) + delay = min(delay * backoff_factor, max_delay) + + raise last_exception + + +def get_block_hash(client: Client, raw_dataset: str, block_num: int) -> str: + """Get block hash from dataset.blocks table.""" + query = f'SELECT hash FROM "{raw_dataset}".blocks WHERE block_num = {block_num} LIMIT 1' + result = client.get_sql(query, read_all=True) + hash_val = result.to_pydict()['hash'][0] + return '0x' + hash_val.hex() if isinstance(hash_val, bytes) else hash_val + + +def get_latest_block(client: Client, raw_dataset: str) -> int: + """Get latest block number from dataset.blocks table.""" + query = f'SELECT block_num FROM "{raw_dataset}".blocks ORDER BY block_num DESC LIMIT 1' + logger.debug(f'Fetching latest block from {raw_dataset}') + logger.debug(f'Query: {query}') + result = client.get_sql(query, read_all=True) + block_num = result.to_pydict()['block_num'][0] + logger.info(f'Latest block in {raw_dataset}: {block_num}') + return block_num + + +def create_watermark(client: Client, raw_dataset: str, network: str, start_block: int) -> ResumeWatermark: + """Create a resume watermark for the given start block.""" + watermark_block = start_block - 1 + watermark_hash = get_block_hash(client, raw_dataset, watermark_block) + return ResumeWatermark( + ranges=[BlockRange(network=network, start=watermark_block, end=watermark_block, hash=watermark_hash)] + ) + + +def main( + amp_server: str, + kafka_brokers: str, + topic: str, + query_file: str, + raw_dataset: str, + network: str, + start_block: str = None, + label_csv: str = None, + state_dir: str = '.amp_state', + auth: bool = False, + auth_token: str = None, + max_retries: int = 5, + retry_delay: float = 1.0, + kafka_config: dict = None, + reorg_topic: str = None, +): + def connect(): + return Client(amp_server, auth=auth, auth_token=auth_token) + + client = retry_with_backoff(connect, max_retries=max_retries, initial_delay=retry_delay) + logger.info(f'Connected to {amp_server}') + + if label_csv and Path(label_csv).exists(): + client.configure_label('tokens', label_csv) + logger.info(f'Loaded {len(client.label_manager.get_label("tokens"))} labels from {label_csv}') + label_config = LabelJoinConfig( + label_name='tokens', label_key_column='token_address', stream_key_column='token_address' + ) + else: + label_config = None + + connection_config = { + 'bootstrap_servers': kafka_brokers, + 'client_id': 'amp-kafka-loader', + 'state': {'enabled': True, 'storage': 'lmdb', 'data_dir': state_dir}, + } + if reorg_topic: + connection_config['reorg_topic'] = reorg_topic + if kafka_config: + connection_config.update(kafka_config) + client.configure_connection('kafka', 'kafka', 
connection_config) + + with open(query_file) as f: + query = f.read() + + if start_block == 'latest': + block = get_latest_block(client, raw_dataset) + resume_watermark = create_watermark(client, raw_dataset, network, block) + logger.info(f'Starting from latest block {block}') + elif start_block is not None: + block = int(start_block) + resume_watermark = create_watermark(client, raw_dataset, network, block) if block > 0 else None + logger.info(f'Starting from block {block}') + else: + resume_watermark = None + logger.info('Resuming from LMDB state') + logger.info(f'Streaming to Kafka: {kafka_brokers} -> {topic}') + + batch_count = 0 + + def stream_batches(): + nonlocal batch_count + for result in client.sql(query).load( + 'kafka', topic, stream=True, label_config=label_config, resume_watermark=resume_watermark + ): + if result.success: + batch_count += 1 + block_info = '' + if result.metadata and result.metadata.get('block_ranges'): + ranges = result.metadata['block_ranges'] + parts = [f'{r["network"]}:{r["start"]}-{r["end"]}' for r in ranges] + block_info = f' [{", ".join(parts)}]' + logger.info(f'Batch {batch_count}: {result.rows_loaded} rows in {result.duration:.2f}s{block_info}') + else: + logger.error(f'Batch error: {result.error}') + + retry_with_backoff(stream_batches, max_retries=max_retries, initial_delay=retry_delay) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Stream data to Kafka with resume watermark') + parser.add_argument('--amp-server', default=os.getenv('AMP_SERVER_URL', 'grpc://127.0.0.1:1602')) + parser.add_argument('--kafka-brokers', default='localhost:9092') + parser.add_argument('--topic', required=True) + parser.add_argument('--reorg-topic', help='Separate topic for reorg messages (default: same as --topic)') + parser.add_argument('--query-file', required=True) + parser.add_argument( + '--raw-dataset', required=True, help='Dataset name for the raw dataset of the chain (e.g., anvil, eth_firehose)' + ) + parser.add_argument('--network', default='anvil') + parser.add_argument('--start-block', type=str, help='Start from specific block number or "latest"') + parser.add_argument('--label-csv', help='Optional CSV for label joining') + parser.add_argument('--state-dir', default='.amp_state', help='Directory for LMDB state storage') + parser.add_argument('--auth', action='store_true', help='Enable auth using ~/.amp/cache or AMP_AUTH_TOKEN env var') + parser.add_argument('--auth-token', help='Explicit auth token (works independently, does not require --auth)') + parser.add_argument('--max-retries', type=int, default=5, help='Max retries for connection failures (default: 5)') + parser.add_argument('--retry-delay', type=float, default=1.0, help='Initial retry delay in seconds (default: 1.0)') + parser.add_argument( + '--kafka-config', + type=str, + help='Extra Kafka producer config as JSON. Uses kafka-python naming (underscores). ' + 'Example: \'{"compression_type": "lz4", "linger_ms": 5}\'. 
' + 'See: https://kafka-python.readthedocs.io/en/master/apidoc/KafkaProducer.html', + ) + parser.add_argument( + '--kafka-config-file', + type=Path, + help='Path to JSON file with extra Kafka producer config', + ) + parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR']) + args = parser.parse_args() + + logging.basicConfig(level=logging.WARNING, format='%(asctime)s [%(name)s] %(levelname)s: %(message)s') + log_level = getattr(logging, args.log_level) if args.log_level else logging.INFO + logging.getLogger('amp').setLevel(log_level) + + kafka_config = {} + if args.kafka_config_file: + kafka_config = json.loads(args.kafka_config_file.read_text()) + logger.info(f'Loaded Kafka config from {args.kafka_config_file}') + if args.kafka_config: + kafka_config.update(json.loads(args.kafka_config)) + + try: + main( + amp_server=args.amp_server, + kafka_brokers=args.kafka_brokers, + topic=args.topic, + query_file=args.query_file, + raw_dataset=args.raw_dataset, + network=args.network, + start_block=args.start_block, + label_csv=args.label_csv, + state_dir=args.state_dir, + auth=args.auth, + auth_token=args.auth_token, + max_retries=args.max_retries, + retry_delay=args.retry_delay, + kafka_config=kafka_config or None, + reorg_topic=args.reorg_topic, + ) + except KeyboardInterrupt: + logger.info('Stopped by user') + except Exception as e: + logger.error(f'Fatal error: {e}') + raise diff --git a/apps/kafka_streaming_loader_guide.md b/apps/kafka_streaming_loader_guide.md new file mode 100644 index 0000000..418eae7 --- /dev/null +++ b/apps/kafka_streaming_loader_guide.md @@ -0,0 +1,262 @@ +# Kafka Streaming Loader - Usage Guide + +Stream blockchain data to Kafka topics in real-time. + +## Quick Start + +```bash +uv run python apps/kafka_streaming_loader.py \ + --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ + --kafka-brokers localhost:9092 \ + --topic erc20_transfers \ + --query-file apps/queries/erc20_transfers_activity.sql \ + --raw-dataset 'edgeandnode/ethereum_mainnet' \ + --network ethereum-mainnet +``` + +## Basic Usage + +### Minimal Example (Staging Gateway) + +```bash +uv run python apps/kafka_streaming_loader.py \ + --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ + --kafka-brokers localhost:9092 \ + --topic my_topic \ + --query-file my_query.sql \ + --raw-dataset 'edgeandnode/ethereum_mainnet' \ + --network ethereum-mainnet +``` + +### Local Development (Anvil) + +```bash +uv run python apps/kafka_streaming_loader.py \ + --topic anvil_logs \ + --query-file apps/queries/anvil_logs.sql \ + --raw-dataset anvil \ + --start-block 0 +``` + +## Configuration Options + +### Required Arguments + +| Argument | Description | +|----------|-------------| +| `--topic NAME` | Kafka topic name | +| `--query-file PATH` | Path to SQL query file | +| `--raw-dataset NAME` | Dataset name (e.g., `edgeandnode/ethereum_mainnet`, `anvil`) | + +### Optional Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--amp-server URL` | `grpc://127.0.0.1:1602` | AMP server URL (use `grpc+tls://gateway.amp.staging.thegraph.com:443` for staging) | +| `--kafka-brokers` | `localhost:9092` | Kafka broker addresses | +| `--network NAME` | `anvil` | Network identifier (e.g., `ethereum-mainnet`, `anvil`) | +| `--start-block N` | Resume from state | Block number or `latest` to start from | +| `--reorg-topic NAME` | Same as `--topic` | Separate topic for reorg messages | +| `--label-csv PATH` | - | CSV file for data enrichment | +| 
`--state-dir PATH` | `.amp_state` | Directory for LMDB state storage | +| `--auth` | - | Enable auth using `~/.amp/cache` or `AMP_AUTH_TOKEN` env var | +| `--auth-token TOKEN` | - | Explicit auth token | + +## Message Format + +### Data Messages + +Each row is sent as JSON with `_type: 'data'`: + +```json +{ + "_type": "data", + "block_num": 19000000, + "tx_hash": "0x123...", + "address": "0xabc...", + "data": "0x..." +} +``` + +### Reorg Messages + +On blockchain reorganizations, reorg events are sent: + +```json +{ + "_type": "reorg", + "network": "ethereum-mainnet", + "start_block": 19000100, + "end_block": 19000110, + "last_valid_hash": "0xabc123..." +} +``` + +Consumers should invalidate data in the specified block range. Use `--reorg-topic` to send these to a separate topic (useful for Snowflake Kafka connector which requires strict schema per topic). + +## Examples + +### Stream ERC20 Transfers + +Stream ERC20 transfer events with the activity schema: + +```bash +uv run python apps/kafka_streaming_loader.py \ + --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ + --kafka-brokers localhost:9092 \ + --topic erc20_transfers \ + --query-file apps/queries/erc20_transfers_activity.sql \ + --raw-dataset 'edgeandnode/ethereum_mainnet' \ + --network ethereum-mainnet +``` + +#### Token Metadata Enrichment (Optional) + +To enrich transfer events with token metadata (symbol, name, decimals), add a CSV file: + +1. **Obtain the token metadata CSV**: + - Download from your token metadata source + - Or export from a database with token information + - Required columns: `token_address`, `symbol`, `name`, `decimals` + +2. **Place the CSV file** in the `data/` directory: + ```bash + mkdir -p data + # Copy your CSV file + cp /path/to/your/tokens.csv data/eth_mainnet_token_metadata.csv + ``` + +3. **Run with label enrichment**: + ```bash + uv run python apps/kafka_streaming_loader.py \ + --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ + --kafka-brokers localhost:9092 \ + --topic erc20_transfers \ + --query-file apps/queries/erc20_transfers_activity.sql \ + --raw-dataset 'edgeandnode/ethereum_mainnet' \ + --network ethereum-mainnet \ + --label-csv data/eth_mainnet_token_metadata.csv + ``` + +**CSV Format Example**: +```csv +token_address,symbol,name,decimals +0xe0f066cb646256d33cae9a32c7b144ccbd248fdd,gg unluck,gg unluck,18 +0xabb2a7bec4604491e85a959177cc0e95f60c6bd5,RTX,Remittix,3 +``` + +Without the CSV file, `token_symbol`, `token_name`, and `token_decimals` will be `null` in the output. 
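+
+Before streaming with `--label-csv`, you can optionally sanity-check the CSV header. This is an illustrative snippet (not part of the loader); it only assumes the column names and file path shown above:
+
+```python
+import csv
+
+# Columns the label join expects (see "Required columns" above)
+REQUIRED = {"token_address", "symbol", "name", "decimals"}
+
+with open("data/eth_mainnet_token_metadata.csv", newline="") as f:
+    header = set(csv.DictReader(f).fieldnames or [])
+
+missing = REQUIRED - header
+print("CSV looks good" if not missing else f"Missing columns: {missing}")
+```
+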
+ +### Stream from Latest Block + +```bash +uv run python apps/kafka_streaming_loader.py \ + --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ + --kafka-brokers localhost:9092 \ + --topic eth_live_logs \ + --query-file apps/queries/all_logs.sql \ + --raw-dataset 'edgeandnode/ethereum_mainnet' \ + --network ethereum-mainnet \ + --start-block latest \ + --auth +``` + +### Local Development (Anvil) + +```bash +uv run python apps/kafka_streaming_loader.py \ + --topic anvil_logs \ + --query-file apps/queries/anvil_logs.sql \ + --raw-dataset anvil \ + --start-block 0 +``` + +## Consuming Messages + +Use the consumer script to view messages: + +```bash +# Basic usage +uv run python apps/kafka_consumer.py anvil_logs + +# Custom broker +uv run python apps/kafka_consumer.py anvil_logs localhost:9092 + +# Custom consumer group +uv run python apps/kafka_consumer.py anvil_logs localhost:9092 my-group +``` + +## Docker Usage + +### Build the loader image + +```bash +docker build -f Dockerfile.kafka -t amp-kafka . +``` + +### Quick demo: local Kafka with Docker + +This section runs a single-broker Kafka in Docker for quick testing. For production, point `--kafka-brokers` at your real Kafka cluster and skip this section. + +Start a single-broker Kafka using `confluentinc/cp-kafka`: + +```bash +docker network create kafka-net + +docker run -d --name kafka --network kafka-net -p 9092:9092 \ + -e KAFKA_NODE_ID=1 \ + -e KAFKA_PROCESS_ROLES=broker,controller \ + -e KAFKA_CONTROLLER_QUORUM_VOTERS=1@kafka:9093 \ + -e KAFKA_LISTENERS=PLAINTEXT://0.0.0.0:29092,CONTROLLER://0.0.0.0:9093,EXTERNAL://0.0.0.0:9092 \ + -e KAFKA_ADVERTISED_LISTENERS=PLAINTEXT://kafka:29092,EXTERNAL://localhost:9092 \ + -e KAFKA_LISTENER_SECURITY_PROTOCOL_MAP=PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT,EXTERNAL:PLAINTEXT \ + -e KAFKA_CONTROLLER_LISTENER_NAMES=CONTROLLER \ + -e KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR=1 \ + -e KAFKA_TRANSACTION_STATE_LOG_MIN_ISR=1 \ + -e CLUSTER_ID=MkU3OEVBNTcwNTJENDM2Qk \ + confluentinc/cp-kafka:latest +``` + +This configures two listeners: +- `kafka:29092` — for containers on the `kafka-net` network (use this from the loader) +- `localhost:9092` — for host access (use this from `uv run` or the consumer script) + +### Run the loader + +```bash +docker run -d \ + --name amp-kafka-loader \ + --network kafka-net \ + -e AMP_AUTH_TOKEN \ + -v $(pwd)/apps/queries:/data/queries \ + -v $(pwd)/.amp_state:/data/state \ + amp-kafka \ + --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ + --kafka-brokers kafka:29092 \ + --topic erc20_transfers \ + --query-file /data/queries/erc20_transfers_activity.sql \ + --raw-dataset 'edgeandnode/ethereum_mainnet' \ + --network ethereum-mainnet \ + --state-dir /data/state \ + --start-block latest \ + --auth + +# Check logs +docker logs -f amp-kafka-loader +``` + +### Consume messages from host + +```bash +uv run python apps/kafka_consumer.py erc20_transfers localhost:9092 +``` + +## Getting Help + +```bash +# View all options +uv run python apps/kafka_streaming_loader.py --help + +# View this guide +cat apps/kafka_streaming_loader_guide.md +``` diff --git a/apps/queries/anvil_logs.sql b/apps/queries/anvil_logs.sql new file mode 100644 index 0000000..b08f3db --- /dev/null +++ b/apps/queries/anvil_logs.sql @@ -0,0 +1,7 @@ +SELECT + block_num, + tx_hash, + log_index, + address, + topic0 +FROM anvil.logs diff --git a/apps/queries/erc20_transfers_activity.sql b/apps/queries/erc20_transfers_activity.sql new file mode 100644 index 0000000..06ab322 --- 
/dev/null +++ b/apps/queries/erc20_transfers_activity.sql @@ -0,0 +1,69 @@ +-- ERC20 Transfer Events Query (Activity Schema) +-- +-- This query decodes ERC20 Transfer events from raw Ethereum logs +-- and formats them according to the activity schema for Kafka streaming. +-- +-- Output Schema: +-- - version: Schema version (1.0) +-- - chain: Chain identifier (e.g., 1 for Ethereum mainnet) +-- - block_num: Block number +-- - block_hash: Block hash (0x prefixed hex) +-- - transaction: Transaction hash (0x prefixed hex) +-- - activity_identity: Activity identity type (e.g., 'wallet') +-- - activity_type: Activity type (e.g., 'payment') +-- - activity_address: Primary address involved in activity +-- - activity_operations: Array of operations with debited/credited amounts +-- - token_address: ERC20 token contract address (0x prefixed hex) +-- - token_symbol: Token symbol (from label join) +-- - token_name: Token name (from label join) +-- - token_decimals: Token decimals (from label join) +-- +-- Required columns for parallel loading: +-- - block_num: Used for partitioning across workers +-- +-- Label join column (if using --label-csv): +-- - token_address: Binary address of the ERC20 token contract +-- +-- Example usage: +-- uv run python apps/kafka_streaming_loader.py \ +-- --amp-server 'grpc+tls://gateway.amp.staging.thegraph.com:443' \ +-- --kafka-brokers localhost:9092 \ +-- --topic erc20_transfers \ +-- --query-file apps/queries/erc20_transfers_activity.sql \ +-- --raw-dataset 'edgeandnode/ethereum_mainnet' \ +-- --network ethereum-mainnet \ +-- --label-csv data/eth_mainnet_token_metadata.csv + +select + 1.0 as version, + 1 as chain, + l.block_num, + l.block_hash, + l.tx_hash as transaction, + 'wallet' as activity_identity, + 'payment' as activity_type, + evm_decode(l.topic1, l.topic2, l.topic3, l.data, 'Transfer(address indexed from, address indexed to, uint256 value)')['from'] as activity_address, + [ + struct( + concat('log:', cast(l.block_num as string), ':', cast(l.log_index as string)) as id, + l.address as token, + evm_decode(l.topic1, l.topic2, l.topic3, l.data, 'Transfer(address indexed from, address indexed to, uint256 value)')['from'] as address, + -cast(evm_decode(l.topic1, l.topic2, l.topic3, l.data, 'Transfer(address indexed from, address indexed to, uint256 value)')['value'] as double) as amount, + 'debited' as type + ), + struct( + '' as id, + l.address as token, + evm_decode(l.topic1, l.topic2, l.topic3, l.data, 'Transfer(address indexed from, address indexed to, uint256 value)')['to'] as address, + cast(evm_decode(l.topic1, l.topic2, l.topic3, l.data, 'Transfer(address indexed from, address indexed to, uint256 value)')['value'] as double) as amount, + 'credited' as type + ) + ] as activity_operations, + l.address as token_address, + cast(null as string) as token_symbol, + cast(null as string) as token_name, + cast(null as int) as token_decimals +from "edgeandnode/ethereum_mainnet".logs l +where + l.topic0 = evm_topic('Transfer(address indexed from, address indexed to, uint256 value)') and + l.topic3 IS NULL diff --git a/apps/test_kafka_query.py b/apps/test_kafka_query.py new file mode 100644 index 0000000..8bc7422 --- /dev/null +++ b/apps/test_kafka_query.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +Test ERC20 query with label joining +""" + +import json +import os +import time + +from kafka import KafkaConsumer + +from amp.client import Client +from amp.loaders.types import LabelJoinConfig + +# Connect to Amp server +server_url = os.getenv('AMP_SERVER_URL', 
'grpc://34.27.238.174:80') +print(f'Connecting to {server_url}...') +client = Client(server_url) +print('✅ Connected!') + +# Simple ERC20 transfer query +transfer_sig = 'Transfer(address indexed from, address indexed to, uint256 value)' +query = f""" + SELECT + block_num, + tx_hash, + address as token_address, + evm_decode(topic1, topic2, topic3, data, '{transfer_sig}')['from'] as from_address, + evm_decode(topic1, topic2, topic3, data, '{transfer_sig}')['to'] as to_address, + evm_decode(topic1, topic2, topic3, data, '{transfer_sig}')['value'] as value + FROM eth_firehose.logs + WHERE topic0 = evm_topic('{transfer_sig}') + AND topic3 IS NULL + LIMIT 10 +""" + +print('\nRunning query...') +result = client.get_sql(query, read_all=True) + +print(f'Got {result.num_rows} rows') +print(f'Columns: {result.schema.names}') + +print('Testing label join') + + +csv_path = 'data/eth_mainnet_token_metadata.csv' +client.configure_label('tokens', csv_path) +print(f'Loaded {len(client.label_manager.get_label("tokens"))} tokens from CSV') + +label_config = LabelJoinConfig( + label_name='tokens', + label_key_column='token_address', + stream_key_column='token_address', +) + +print('Configured label join: will add symbol, name, decimals columns') + + +print('Loading to Kafka with labels') + + +kafka_config = { + 'bootstrap_servers': os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'localhost:9092'), + 'client_id': 'amp-erc20-loader', +} +client.configure_connection('kafka', 'kafka', kafka_config) + +results = list( + client.sql(query).load( + connection='kafka', + destination='erc20_transfers', + label_config=label_config, + ) +) + +total_rows = sum(r.rows_loaded for r in results if r.success) +print(f'\nLoaded {total_rows} enriched rows to Kafka topic "erc20_transfers"') + +print('\n' + '=' * 60) +print('Reading back from Kafka') +print('=' * 60) + +time.sleep(1) + +consumer = KafkaConsumer( + 'erc20_transfers', + bootstrap_servers=kafka_config['bootstrap_servers'], + auto_offset_reset='earliest', + value_deserializer=lambda x: json.loads(x.decode('utf-8')), + consumer_timeout_ms=5000, +) + +print('\nConsuming messages from topic "erc20_transfers":\n') +msg_count = 0 +for message in consumer: + msg_count += 1 + data = message.value + print(f'Message {msg_count}:') + print(f' token_address: {data.get("token_address")}') + print(f' symbol: {data.get("symbol")}') + print(f' name: {data.get("name")}') + print(f' decimals: {data.get("decimals")}') + print(f' value: {data.get("value")}') + print(f' from_address: {data.get("from_address")}') + print() + +consumer.close() +print(f'Read {msg_count} messages from Kafka') diff --git a/notebooks/kafka_streaming.py b/notebooks/kafka_streaming.py new file mode 100644 index 0000000..8dfca98 --- /dev/null +++ b/notebooks/kafka_streaming.py @@ -0,0 +1,122 @@ +import marimo + +__generated_with = '0.17.0' +app = marimo.App(width='medium') + + +@app.cell +def _(): + import marimo as mo + + from amp.client import Client + + return Client, mo + + +@app.cell(hide_code=True) +def _(mo): + mo.md( + r""" + # Kafka Streaming Example + + This notebook demonstrates continuous streaming from Flight SQL to Kafka with reorg detection. 
+ """ + ) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Setup""") + return + + +@app.cell +def _(Client): + client = Client('grpc://127.0.0.1:1602') + return (client,) + + +@app.cell +def _(client): + client.configure_connection( + 'my_kafka', + 'kafka', + {'bootstrap_servers': 'localhost:9092', 'client_id': 'amp-streaming-client', 'key_field': 'block_num'}, + ) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md( + r""" + ## Streaming Query + + This query uses `SETTINGS stream = true` to continuously stream new blocks as they arrive. + The loader will automatically handle blockchain reorganizations. + """ + ) + return + + +@app.cell +def _(client): + streaming_results = client.sql( + """ + SELECT + block_num, + log_index + FROM anvil.logs + """ + ).load( + 'my_kafka', + 'eth_logs_stream', + stream=True, + create_table=True, + ) + return (streaming_results,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md( + r""" + ## Monitor Stream + + This cell will continuously print results as they arrive. It starts a Kafka consumer to print + the results as they come in. + """ + ) + return + + +@app.cell +def _(streaming_results): + import json + import threading + + from kafka import KafkaConsumer + + def consume_kafka(): + consumer = KafkaConsumer( + 'eth_logs_stream', + bootstrap_servers='localhost:9092', + auto_offset_reset='latest', + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + print('Kafka Consumer started') + for message in consumer: + print(f'Consumed: {message.value}') + + consumer_thread = threading.Thread(target=consume_kafka, daemon=True) + consumer_thread.start() + + print('Kafka Producer started') + for result in streaming_results: + print(f'Produced: {result}') + return + + +if __name__ == '__main__': + app.run() diff --git a/notebooks/test_loaders.py b/notebooks/test_loaders.py index af8d7f9..08f9bda 100644 --- a/notebooks/test_loaders.py +++ b/notebooks/test_loaders.py @@ -1,6 +1,6 @@ import marimo -__generated_with = "0.14.16" +__generated_with = "0.17.0" app = marimo.App(width="full") @@ -89,7 +89,7 @@ def _(mo): create_table=True, ) """, - name='_', + name="_" ) @@ -97,7 +97,7 @@ def _(mo): def _(psql_load_results): for p_result in psql_load_results: print(p_result) - return (p_result,) + return @app.cell(hide_code=True) @@ -120,7 +120,7 @@ def _(client): def _(redis_load_results): for r_result in redis_load_results: print(r_result) - return (r_result,) + return @app.cell(hide_code=True) @@ -149,7 +149,7 @@ def _(client): else: # Single result print(f'Total: {result.rows_loaded} rows') - return batch_result, result + return (batch_result,) @app.cell @@ -291,7 +291,7 @@ def _(lmdb_load_result): def _(batch_result, lmdb_load_result): for lmdb_batch_result in lmdb_load_result: print(f'Batch: {batch_result.rows_loaded} rows') - return (lmdb_batch_result,) + return @app.cell @@ -325,7 +325,7 @@ def _(env): myList = [ key for key, _ in txn.cursor() ] print(myList) print(len(myList)) - return myList, txn + return @app.cell @@ -340,7 +340,74 @@ def _(env, pa): batch = reader.read_next_batch() print(batch) - return batch, key, open_txn, reader, value + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""# Kafka""") + return + + +@app.cell +def _(client): + # Configure Kafka connection + client.configure_connection( + 'my_kafka', + 'kafka', + { + 'bootstrap_servers': 'localhost:9092', + 'client_id': 'amp-test-client', + 'key_field': 'id' + } + ) + return + + +@app.cell +def _(client): + # Load data to Kafka topic + 
kafka_load_results = client.sql('select * from eth_firehose.logs limit 100').load( + 'my_kafka', + 'test_logs', + create_table=True, + ) + return (kafka_load_results,) + + +@app.cell +def _(kafka_load_results): + # Check results + for k_result in kafka_load_results: + print(f'Kafka batch: {k_result.rows_loaded} rows loaded, duration: {k_result.duration:.2f}s') + return (k_result,) + + +@app.cell +def _(): + from kafka import KafkaConsumer + import json + + consumer = KafkaConsumer( + 'test_logs', + bootstrap_servers='localhost:9092', + auto_offset_reset='earliest', + consumer_timeout_ms=3000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')) + ) + + messages = list(consumer) + consumer.close() + + print(f"Total messages in Kafka: {len(messages)}") + print(f"\nFirst message:") + if messages: + msg = messages[0].value + print(f" Block: {msg.get('block_num')}") + print(f" Timestamp: {msg.get('timestamp')}") + print(f" Address: {msg.get('address')}") + + return @app.cell diff --git a/pyproject.toml b/pyproject.toml index 258bc24..efa3548 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,15 +66,20 @@ lmdb = [ "lmdb>=1.4.0", ] +kafka = [ + "kafka-python>=2.2.15", +] + all_loaders = [ - "psycopg2-binary>=2.9.0", # PostgreSQL - "redis>=4.5.0", # Redis - "deltalake>=1.0.2", # Delta Lake (consistent version) + "psycopg2-binary>=2.9.0", # PostgreSQL + "redis>=4.5.0", # Redis + "deltalake>=1.0.2", # Delta Lake (consistent version) "pyiceberg[sql-sqlite]>=0.10.0", # Apache Iceberg "pydantic>=2.0,<2.12", # PyIceberg 0.10.0 compatibility "snowflake-connector-python>=4.0.0", # Snowflake "snowpipe-streaming>=1.0.0", # Snowpipe Streaming API "lmdb>=1.4.0", # LMDB + "kafka-python>=2.2.15", ] test = [ diff --git a/src/amp/client.py b/src/amp/client.py index 2eee462..9d463df 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -769,26 +769,23 @@ def query_and_load_streaming( self.logger.info(f'Starting streaming query to {loader_type}:{destination}') - # Create loader instance early to access checkpoint store + # Create loader instance early to access state store loader_instance = create_loader(loader_type, loader_config, label_manager=self.label_manager) - # Load checkpoint and create resume watermark if enabled (default: enabled) + # Load resume position from state store if enabled (default: enabled) if resume_watermark is None and kwargs.get('resume', True): try: - checkpoint = loader_instance.checkpoint_store.load(connection_name, destination) - - if checkpoint: - resume_watermark = checkpoint.to_resume_watermark() - checkpoint_type = 'reorg checkpoint' if checkpoint.is_reorg else 'checkpoint' - self.logger.info( - f'Resuming from {checkpoint_type}: {len(checkpoint.ranges)} ranges, ' - f'timestamp {checkpoint.timestamp}' - ) - if checkpoint.is_reorg: - resume_points = ', '.join(f'{r.network}:{r.start}' for r in checkpoint.ranges) - self.logger.info(f'Reorg resume points: {resume_points}') + loader_instance.connect() + resume_watermark = loader_instance.state_store.get_resume_position( + connection_name, destination, detect_gaps=False + ) + + if resume_watermark: + self.logger.info(f'Resuming from state store: {len(resume_watermark.ranges)} ranges') + resume_points = ', '.join(f'{r.network}:{r.end}' for r in resume_watermark.ranges) + self.logger.info(f'Resume points (max processed blocks): {resume_points}') except Exception as e: - self.logger.warning(f'Failed to load checkpoint, starting from beginning: {e}') + self.logger.warning(f'Failed to load resume position, starting from 
beginning: {e}') try: # Execute streaming query with Flight SQL @@ -796,20 +793,22 @@ def query_and_load_streaming( command_query = FlightSql_pb2.CommandStatementQuery() command_query.query = query - # Add resume watermark if provided - if resume_watermark: - # TODO: Add watermark to query metadata when Flight SQL supports it - self.logger.info(f'Resuming stream from watermark: {resume_watermark}') - # Wrap the CommandStatementQuery in an Any type any_command = Any() any_command.Pack(command_query) cmd = any_command.SerializeToString() + # Prepare Flight call options with headers + call_options = None + if resume_watermark: + watermark_json = resume_watermark.to_json() + self.logger.info(f'Resuming stream from watermark: {watermark_json}') + call_options = flight.FlightCallOptions(headers=[(b'amp-resume', watermark_json.encode('utf-8'))]) + self.logger.info('Establishing Flight SQL connection...') flight_descriptor = flight.FlightDescriptor.for_command(cmd) - info = self.conn.get_flight_info(flight_descriptor) - reader = self.conn.do_get(info.endpoints[0].ticket) + info = self.conn.get_flight_info(flight_descriptor, options=call_options) + reader = self.conn.do_get(info.endpoints[0].ticket, options=call_options) # Create streaming iterator stream_iterator = StreamingResultIterator(reader) diff --git a/src/amp/loaders/base.py b/src/amp/loaders/base.py index 05847eb..a2775eb 100644 --- a/src/amp/loaders/base.py +++ b/src/amp/loaders/base.py @@ -78,11 +78,19 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None: else: self.state_store = NullStreamStateStore() + # Track tables that have undergone crash recovery + self._crash_recovery_done: set[str] = set() + @property def is_connected(self) -> bool: """Check if the loader is connected to the target system.""" return self._is_connected + @property + def loader_type(self) -> str: + """Get the loader type identifier (e.g., 'postgresql', 'redis').""" + return self.__class__.__name__.replace('Loader', '').lower() + def _parse_config(self, config: Dict[str, Any]) -> TConfig: """ Parse configuration into loader-specific format. @@ -455,11 +463,21 @@ def load_stream_continuous( if not self._is_connected: self.connect() + connection_name = kwargs.get('connection_name') + if connection_name is None: + connection_name = self.loader_type + + if table_name not in self._crash_recovery_done: + self.logger.info(f'Running crash recovery for table {table_name} (connection: {connection_name})') + self._rewind_to_watermark(table_name, connection_name) + self._crash_recovery_done.add(table_name) + else: + self.logger.info(f'Crash recovery already done for table {table_name}') + rows_loaded = 0 start_time = time.time() batch_count = 0 reorg_count = 0 - connection_name = kwargs.get('connection_name', 'unknown') worker_id = kwargs.get('worker_id', 0) try: @@ -793,6 +811,68 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, 'Streaming with reorg detection requires implementing this method.' ) + def _rewind_to_watermark(self, table_name: str, connection_name: Optional[str] = None) -> None: + """ + Reset state and data to the last checkpointed watermark. + + Removes any data written after the last completed watermark, + ensuring resumable streams start from a consistent state. + + This handles crash recovery by removing uncommitted data from + incomplete microbatches between watermarks. + + Args: + table_name: Table to clean up. + connection_name: Connection identifier. If None, uses default. 
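+
+        Example:
+            If the last committed watermark for a network ends at block 1010,
+            blocks 1011 and onwards are invalidated via _handle_reorg(). Loaders
+            that cannot delete data fall back to clearing state only, so
+            duplicates may occur on resume.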
+ """ + if not self.state_enabled: + self.logger.debug('State tracking disabled, skipping crash recovery') + return + + if connection_name is None: + connection_name = self.loader_type + + resume_pos = self.state_store.get_resume_position(connection_name, table_name) + if not resume_pos: + self.logger.debug(f'No watermark found for {table_name}, skipping crash recovery') + return + + for range_obj in resume_pos.ranges: + from_block = range_obj.end + 1 + + self.logger.info( + f'Crash recovery: Cleaning up {table_name} data for {range_obj.network} from block {from_block} onwards' + ) + + invalidation_ranges = [ + BlockRange( + network=range_obj.network, + start=from_block, + end=from_block, + hash=range_obj.hash, + prev_hash=range_obj.prev_hash, + ) + ] + + try: + self._handle_reorg(invalidation_ranges, table_name, connection_name) + self.logger.info(f'Crash recovery completed for {range_obj.network} in {table_name}') + + except NotImplementedError: + invalidated = self.state_store.invalidate_from_block( + connection_name, table_name, range_obj.network, from_block + ) + + if invalidated: + self.logger.warning( + f'Crash recovery: Cleared {len(invalidated)} batches from state ' + f'for {range_obj.network} but cannot delete data from {table_name}. ' + f'{self.__class__.__name__} does not support data deletion. ' + f'Duplicates may occur on resume.' + ) + else: + self.logger.debug(f'No uncommitted batches found for {range_obj.network}') + def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch: """ Add metadata columns for streaming data with compact batch identification. diff --git a/src/amp/loaders/implementations/__init__.py b/src/amp/loaders/implementations/__init__.py index 28617b0..75238bc 100644 --- a/src/amp/loaders/implementations/__init__.py +++ b/src/amp/loaders/implementations/__init__.py @@ -21,7 +21,7 @@ try: from .iceberg_loader import IcebergLoader -except ImportError: +except Exception: IcebergLoader = None try: @@ -34,11 +34,10 @@ except ImportError: LMDBLoader = None -# Add any other loaders here -# try: -# from .snowflake_loader import SnowflakeLoader -# except ImportError: -# SnowflakeLoader = None +try: + from .kafka_loader import KafkaLoader +except ImportError: + KafkaLoader = None __all__ = [] @@ -55,3 +54,5 @@ __all__.append('SnowflakeLoader') if LMDBLoader: __all__.append('LMDBLoader') +if KafkaLoader: + __all__.append('KafkaLoader') diff --git a/src/amp/loaders/implementations/kafka_loader.py b/src/amp/loaders/implementations/kafka_loader.py new file mode 100644 index 0000000..b86f346 --- /dev/null +++ b/src/amp/loaders/implementations/kafka_loader.py @@ -0,0 +1,178 @@ +import json +from dataclasses import dataclass, fields +from typing import Any, Dict, List, Optional + +import pyarrow as pa +from kafka import KafkaProducer + +from ...streaming.lmdb_state import LMDBStreamStateStore +from ...streaming.types import BlockRange +from ..base import DataLoader, LoadMode + + +@dataclass +class KafkaConfig: + bootstrap_servers: str + client_id: str = 'amp-kafka-loader' + key_field: Optional[str] = 'id' + reorg_topic: Optional[str] = None + + +KAFKA_CONFIG_FIELDS = {f.name for f in fields(KafkaConfig)} +RESERVED_CONFIG_FIELDS = {'resilience', 'state', 'checkpoint', 'idempotency'} + + +class KafkaLoader(DataLoader[KafkaConfig]): + SUPPORTED_MODES = {LoadMode.APPEND} + REQUIRES_SCHEMA_MATCH = False + SUPPORTS_TRANSACTIONS = True + + def __init__(self, config: Dict[str, Any], label_manager=None) -> None: + 
self._extra_producer_config = { + k: v for k, v in config.items() if k not in KAFKA_CONFIG_FIELDS and k not in RESERVED_CONFIG_FIELDS + } + super().__init__(config, label_manager) + self._producer = None + + def _get_required_config_fields(self) -> list[str]: + return ['bootstrap_servers'] + + def connect(self) -> None: + try: + producer_config = { + **self._extra_producer_config, + 'bootstrap_servers': self.config.bootstrap_servers, + 'client_id': self.config.client_id, + 'value_serializer': lambda x: json.dumps(x, default=str).encode('utf-8'), + 'transactional_id': f'{self.config.client_id}-txn', + } + if self._extra_producer_config: + self.logger.info(f'Extra Kafka config: {list(self._extra_producer_config.keys())}') + self._producer = KafkaProducer(**producer_config) + + self._producer.init_transactions() + + metadata = self._producer.bootstrap_connected() + self.logger.info(f'Connection status: {metadata}') + self.logger.info(f'Connected to Kafka at {self.config.bootstrap_servers}') + self.logger.info(f'Client ID: {self.config.client_id}') + + if self.state_enabled and self.state_storage == 'lmdb': + self.state_store = LMDBStreamStateStore( + connection_name=self.config.client_id, + data_dir=self.state_data_dir, + ) + self.logger.info(f'Initialized LMDB state store at {self.state_store.data_dir}') + + self._is_connected = True + + except Exception as e: + if self._producer: + self._producer.close() + self._producer = None + self.logger.error(f'Failed to connect to Kafka: {e}') + raise + + def disconnect(self) -> None: + if self._producer: + self._producer.close() + self._producer = None + + if isinstance(self.state_store, LMDBStreamStateStore): + self.state_store.close() + + self._is_connected = False + self.logger.info('Disconnected from Kafka') + + def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None: + self.logger.info(f'Kafka topic {table_name} will be auto-created on first message send') + pass + + def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int: + if not self._producer: + raise RuntimeError('Producer not connected. Call connect() first.') + + data_dict = batch.to_pydict() + num_rows = batch.num_rows + + if num_rows == 0: + return 0 + + self._producer.begin_transaction() + try: + for i in range(num_rows): + row = {field: values[i] for field, values in data_dict.items()} + row['_type'] = 'data' + + key = self._extract_message_key(row) + + self._producer.send(topic=table_name, key=key, value=row) + + self._producer.commit_transaction() + self.logger.debug(f'Committed transaction with {num_rows} messages to topic {table_name}') + + except Exception as e: + self._producer.abort_transaction() + self.logger.error(f'Transaction aborted due to error: {e}') + raise + + return num_rows + + def _extract_message_key(self, row: Dict[str, Any]) -> Optional[bytes]: + if not self.config.key_field or self.config.key_field not in row: + return None + + key_value = row[self.config.key_field] + if key_value is None: + return None + + return str(key_value).encode('utf-8') + + def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, connection_name: str) -> None: + """ + Handle blockchain reorganization by sending reorg events to Kafka. + + Reorg events are sent as special messages with _type='reorg' so consumers + can detect and handle invalidated block ranges. 
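+
+        Illustrative reorg message value (field names match the dict built below):
+
+            {"_type": "reorg", "network": "ethereum", "start_block": 100,
+             "end_block": 200, "last_valid_hash": "0xabc..."}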
+ + Args: + invalidation_ranges: List of block ranges to invalidate + table_name: The Kafka topic name (used if reorg_topic not configured) + connection_name: Connection identifier (required by base class interface) + """ + if not invalidation_ranges: + return + + if not self._producer: + self.logger.warning('Producer not connected, skipping reorg handling') + return + + reorg_topic = self.config.reorg_topic or table_name + + self._producer.begin_transaction() + try: + for invalidation_range in invalidation_ranges: + reorg_message = { + '_type': 'reorg', + 'network': invalidation_range.network, + 'start_block': invalidation_range.start, + 'end_block': invalidation_range.end, + 'last_valid_hash': invalidation_range.hash, + } + + self._producer.send( + topic=reorg_topic, key=f'reorg:{invalidation_range.network}'.encode('utf-8'), value=reorg_message + ) + + self.logger.info( + f'Sent reorg event to {reorg_topic}: ' + f'{invalidation_range.network} blocks {invalidation_range.start}-{invalidation_range.end}' + ) + + self._producer.commit_transaction() + self.logger.info(f'Committed {len(invalidation_ranges)} reorg events to {reorg_topic}') + + except Exception as e: + self._producer.abort_transaction() + self.logger.error(f'Reorg transaction aborted due to error: {e}') + raise diff --git a/src/amp/streaming/reorg.py b/src/amp/streaming/reorg.py index 81702b1..f0ce300 100644 --- a/src/amp/streaming/reorg.py +++ b/src/amp/streaming/reorg.py @@ -46,6 +46,12 @@ def __next__(self) -> ResponseBatch: KeyboardInterrupt: When user cancels the stream """ try: + # Check if we have a pending batch from a previous reorg detection + if hasattr(self, '_pending_batch'): + pending = self._pending_batch + delattr(self, '_pending_batch') + return pending + # Get next batch from underlying stream batch = next(self.stream_iterator) @@ -63,13 +69,6 @@ def __next__(self) -> ResponseBatch: self._pending_batch = batch return ResponseBatch.reorg_batch(invalidation_ranges) - # Check if we have a pending batch from a previous reorg detection - # REVIEW: I think we should remove this - if hasattr(self, '_pending_batch'): - pending = self._pending_batch - delattr(self, '_pending_batch') - return pending - # Normal case - just return the data batch return batch @@ -107,6 +106,7 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]: network=current_range.network, start=current_range.start, end=max(current_range.end, prev_range.end), + hash=prev_range.hash, ) invalidation_ranges.append(invalidation) diff --git a/src/amp/streaming/types.py b/src/amp/streaming/types.py index ba35919..cb1517f 100644 --- a/src/amp/streaming/types.py +++ b/src/amp/streaming/types.py @@ -164,21 +164,19 @@ class ResumeWatermark: """Watermark for resuming streaming queries""" ranges: List[BlockRange] + # TODO: timestamp and sequence are unused. Remove? timestamp: Optional[str] = None sequence: Optional[int] = None def to_json(self) -> str: - """Serialize to JSON string for HTTP headers""" - data = {'ranges': [r.to_dict() for r in self.ranges]} - if self.timestamp: - data['timestamp'] = self.timestamp - if self.sequence is not None: - data['sequence'] = self.sequence - return json.dumps(data) + """Serialize to JSON string for HTTP headers. 
- @classmethod - def from_json(cls, json_str: str) -> 'ResumeWatermark': - """Deserialize from JSON string""" - data = json.loads(json_str) - ranges = [BlockRange.from_dict(r) for r in data['ranges']] - return cls(ranges=ranges, timestamp=data.get('timestamp'), sequence=data.get('sequence')) + Server expects format: {"network_name": {"number": block_num, "hash": "0x..."}, ...} + """ + data = {} + for r in self.ranges: + if r.hash is None: + raise ValueError(f"BlockRange for network '{r.network}' must have a hash for watermark") + + data[r.network] = {'number': r.end, 'hash': r.hash} + return json.dumps(data) diff --git a/tests/conftest.py b/tests/conftest.py index f0bbc4d..a984cb4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,7 @@ # Import testcontainers conditionally if USE_TESTCONTAINERS: try: + from testcontainers.kafka import KafkaContainer from testcontainers.postgres import PostgresContainer from testcontainers.redis import RedisContainer @@ -213,6 +214,51 @@ def redis_streaming_config(redis_test_config): } +@pytest.fixture(scope='session') +def kafka_container(): + """Kafka container for integration tests""" + if not TESTCONTAINERS_AVAILABLE: + pytest.skip('Testcontainers not available') + + # Configure Kafka for transactions in single-broker setup + # These settings are required for transactional producers to work + container = KafkaContainer(image='confluentinc/cp-kafka:7.6.0') + container.with_env('KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR', '1') + container.with_env('KAFKA_TRANSACTION_STATE_LOG_MIN_ISR', '1') + container.with_env('KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR', '1') + container.start() + + time.sleep(10) + + yield container + + container.stop() + + +@pytest.fixture(scope='session') +def kafka_config(): + """Kafka configuration from environment or defaults""" + return { + 'bootstrap_servers': os.getenv('KAFKA_BOOTSTRAP_SERVERS', 'localhost:9092'), + 'client_id': 'amp-test-client', + } + + +@pytest.fixture(scope='session') +def kafka_test_config(request): + """Kafka configuration from testcontainer or environment""" + if TESTCONTAINERS_AVAILABLE and USE_TESTCONTAINERS: + kafka_container = request.getfixturevalue('kafka_container') + bootstrap_servers = kafka_container.get_bootstrap_server() + + return { + 'bootstrap_servers': bootstrap_servers, + 'client_id': 'amp-test-client', + } + else: + return request.getfixturevalue('kafka_config') + + @pytest.fixture(scope='session') def delta_test_env(): """Create Delta Lake test environment for the session""" @@ -456,6 +502,7 @@ def pytest_configure(config): config.addinivalue_line('markers', 'performance: Performance and benchmark tests') config.addinivalue_line('markers', 'postgresql: Tests requiring PostgreSQL') config.addinivalue_line('markers', 'redis: Tests requiring Redis') + config.addinivalue_line('markers', 'kafka: Tests requiring Apache Kafka') config.addinivalue_line('markers', 'delta_lake: Tests requiring Delta Lake') config.addinivalue_line('markers', 'iceberg: Tests requiring Apache Iceberg') config.addinivalue_line('markers', 'snowflake: Tests requiring Snowflake') diff --git a/tests/integration/test_kafka_loader.py b/tests/integration/test_kafka_loader.py new file mode 100644 index 0000000..092d13c --- /dev/null +++ b/tests/integration/test_kafka_loader.py @@ -0,0 +1,305 @@ +import json +from unittest.mock import patch + +import pyarrow as pa +import pytest +from kafka import KafkaConsumer +from kafka.errors import KafkaError + +try: + from src.amp.loaders.implementations.kafka_loader 
import KafkaLoader + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch +except ImportError: + pytest.skip('amp modules not available', allow_module_level=True) + + +@pytest.mark.integration +@pytest.mark.kafka +class TestKafkaLoaderIntegration: + def test_loader_connection(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + + loader.connect() + assert loader._is_connected == True + assert loader._producer is not None + + loader.disconnect() + assert loader._is_connected == False + assert loader._producer is None + + def test_context_manager(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + + with loader: + assert loader._is_connected == True + assert loader._producer is not None + + assert loader._is_connected == False + + def test_load_batch(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + + batch = pa.RecordBatch.from_pydict( + {'id': [1, 2, 3], 'name': ['alice', 'bob', 'charlie'], 'value': [100, 200, 300]} + ) + + with loader: + result = loader.load_batch(batch, 'test_topic') + + assert result.success == True + assert result.rows_loaded == 3 + + def test_message_consumption_verification(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + topic_name = 'test_consumption_topic' + + batch = pa.RecordBatch.from_pydict( + { + 'id': [1, 2, 3], + 'name': ['alice', 'bob', 'charlie'], + 'score': [100, 200, 150], + 'active': [True, False, True], + } + ) + + with loader: + result = loader.load_batch(batch, topic_name) + + assert result.success is True + assert result.rows_loaded == 3 + + consumer = KafkaConsumer( + topic_name, + bootstrap_servers=kafka_test_config['bootstrap_servers'], + auto_offset_reset='earliest', + consumer_timeout_ms=5000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + + messages = list(consumer) + consumer.close() + + assert len(messages) == 3 + + for i, msg in enumerate(messages): + assert msg.key == str(i + 1).encode('utf-8') + assert msg.value['_type'] == 'data' + assert msg.value['id'] == i + 1 + assert msg.value['name'] in ['alice', 'bob', 'charlie'] + assert msg.value['score'] in [100, 200, 150] + assert msg.value['active'] in [True, False] + + msg1 = messages[0] + assert msg1.value['id'] == 1 + assert msg1.value['name'] == 'alice' + assert msg1.value['score'] == 100 + assert msg1.value['active'] is True + + msg2 = messages[1] + assert msg2.value['id'] == 2 + assert msg2.value['name'] == 'bob' + assert msg2.value['score'] == 200 + assert msg2.value['active'] is False + + msg3 = messages[2] + assert msg3.value['id'] == 3 + assert msg3.value['name'] == 'charlie' + assert msg3.value['score'] == 150 + assert msg3.value['active'] is True + + def test_handle_reorg(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + topic_name = 'test_reorg_topic' + + invalidation_ranges = [ + BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'), + BlockRange(network='polygon', start=500, end=600, hash='0xdef456'), + ] + + with loader: + loader._handle_reorg(invalidation_ranges, topic_name, 'test_connection') + + consumer = KafkaConsumer( + topic_name, + bootstrap_servers=kafka_test_config['bootstrap_servers'], + auto_offset_reset='earliest', + consumer_timeout_ms=5000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + + messages = list(consumer) + consumer.close() + + assert len(messages) == 2 + + msg1 = messages[0] + assert msg1.key == b'reorg:ethereum' + assert msg1.value['_type'] == 'reorg' + assert 
msg1.value['network'] == 'ethereum' + assert msg1.value['start_block'] == 100 + assert msg1.value['end_block'] == 200 + assert msg1.value['last_valid_hash'] == '0xabc123' + + msg2 = messages[1] + assert msg2.key == b'reorg:polygon' + assert msg2.value['_type'] == 'reorg' + assert msg2.value['network'] == 'polygon' + assert msg2.value['start_block'] == 500 + assert msg2.value['end_block'] == 600 + assert msg2.value['last_valid_hash'] == '0xdef456' + + def test_handle_reorg_separate_topic(self, kafka_test_config): + config_with_reorg_topic = { + **kafka_test_config, + 'reorg_topic': 'test_reorg_events', + } + loader = KafkaLoader(config_with_reorg_topic) + data_topic = 'test_data_topic_separate' + + batch = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) + invalidation_ranges = [ + BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'), + ] + + with loader: + loader.load_batch(batch, data_topic) + loader._handle_reorg(invalidation_ranges, data_topic, 'test_connection') + + data_consumer = KafkaConsumer( + data_topic, + bootstrap_servers=kafka_test_config['bootstrap_servers'], + auto_offset_reset='earliest', + consumer_timeout_ms=5000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + data_messages = list(data_consumer) + data_consumer.close() + + assert len(data_messages) == 2 + assert all(msg.value['_type'] == 'data' for msg in data_messages) + + reorg_consumer = KafkaConsumer( + 'test_reorg_events', + bootstrap_servers=kafka_test_config['bootstrap_servers'], + auto_offset_reset='earliest', + consumer_timeout_ms=5000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + reorg_messages = list(reorg_consumer) + reorg_consumer.close() + + assert len(reorg_messages) == 1 + assert reorg_messages[0].value['_type'] == 'reorg' + assert reorg_messages[0].value['network'] == 'ethereum' + assert reorg_messages[0].value['start_block'] == 100 + assert reorg_messages[0].value['end_block'] == 200 + + def test_streaming_with_reorg(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + topic_name = 'test_streaming_topic' + + data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) + data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) + data3 = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) + + response1 = ResponseBatch.data_batch( + data=data1, + metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]), + ) + + response2 = ResponseBatch.data_batch( + data=data2, + metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=110, end=120, hash='0xdef456')]), + ) + + reorg_response = ResponseBatch.reorg_batch( + invalidation_ranges=[BlockRange(network='ethereum', start=110, end=200, hash='0xdef456')] + ) + + response3 = ResponseBatch.data_batch( + data=data3, + metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=110, end=120, hash='0xnew123')]), + ) + + stream = [response1, response2, reorg_response, response3] + + with loader: + results = list(loader.load_stream_continuous(iter(stream), topic_name)) + + assert len(results) == 4 + assert results[0].success + assert results[0].rows_loaded == 2 + assert results[1].success + assert results[1].rows_loaded == 2 + assert results[2].success + assert results[2].is_reorg + assert results[3].success + assert results[3].rows_loaded == 2 + + consumer = KafkaConsumer( + topic_name, + bootstrap_servers=kafka_test_config['bootstrap_servers'], + auto_offset_reset='earliest', + 
consumer_timeout_ms=5000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + + messages = list(consumer) + consumer.close() + + assert len(messages) == 7 + + data_messages = [msg for msg in messages if msg.value.get('_type') == 'data'] + reorg_messages = [msg for msg in messages if msg.value.get('_type') == 'reorg'] + + assert len(data_messages) == 6 + assert len(reorg_messages) == 1 + + assert reorg_messages[0].key == b'reorg:ethereum' + assert reorg_messages[0].value['network'] == 'ethereum' + assert reorg_messages[0].value['start_block'] == 110 + assert reorg_messages[0].value['end_block'] == 200 + assert reorg_messages[0].value['last_valid_hash'] == '0xdef456' + + data_ids = [msg.value['id'] for msg in data_messages] + assert data_ids == [1, 2, 3, 4, 5, 6] + + def test_transaction_rollback_on_error(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + topic_name = 'test_rollback_topic' + + batch = pa.RecordBatch.from_pydict( + { + 'id': [1, 2, 3, 4, 5], + 'name': ['alice', 'bob', 'charlie', 'dave', 'eve'], + 'value': [100, 200, 300, 400, 500], + } + ) + + with loader: + call_count = [0] + + original_send = loader._producer.send + + def failing_send(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 3: + raise KafkaError('Simulated Kafka send failure') + return original_send(*args, **kwargs) + + with patch.object(loader._producer, 'send', side_effect=failing_send): + with pytest.raises(RuntimeError, match='FATAL: Permanent error loading batch'): + loader.load_batch(batch, topic_name) + + consumer = KafkaConsumer( + topic_name, + bootstrap_servers=kafka_test_config['bootstrap_servers'], + auto_offset_reset='earliest', + consumer_timeout_ms=5000, + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + ) + + messages = list(consumer) + consumer.close() + + assert len(messages) == 0 diff --git a/tests/unit/test_crash_recovery.py b/tests/unit/test_crash_recovery.py new file mode 100644 index 0000000..76a4ae9 --- /dev/null +++ b/tests/unit/test_crash_recovery.py @@ -0,0 +1,108 @@ +""" +Unit tests for crash recovery via _rewind_to_watermark() method. + +These tests verify the crash recovery logic works correctly in isolation. 
+""" + +from unittest.mock import Mock + +import pytest + +from src.amp.streaming.types import BlockRange, ResumeWatermark +from tests.fixtures.mock_clients import MockDataLoader + + +@pytest.fixture +def mock_loader() -> MockDataLoader: + """Create a mock loader with state store""" + loader = MockDataLoader({'test': 'config'}) + loader.connect() + + loader.state_store = Mock() + loader.state_enabled = True + + return loader + + +@pytest.mark.unit +class TestCrashRecovery: + """Test _rewind_to_watermark() crash recovery method""" + + def test_rewind_with_no_state(self, mock_loader): + """Should return early if state_enabled=False""" + mock_loader.state_enabled = False + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader.state_store.get_resume_position.assert_not_called() + + def test_rewind_with_no_watermark(self, mock_loader): + """Should return early if no watermark exists""" + mock_loader.state_store.get_resume_position = Mock(return_value=None) + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader.state_store.get_resume_position.assert_called_once_with('test_conn', 'test_table') + + def test_rewind_calls_handle_reorg(self, mock_loader): + """Should call _handle_reorg with correct invalidation ranges""" + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader._handle_reorg.assert_called_once() + call_args = mock_loader._handle_reorg.call_args + invalidation_ranges = call_args[0][0] + assert len(invalidation_ranges) == 1 + assert invalidation_ranges[0].network == 'ethereum' + assert invalidation_ranges[0].start == 1011 + assert call_args[0][1] == 'test_table' + assert call_args[0][2] == 'test_conn' + + def test_rewind_handles_not_implemented(self, mock_loader): + """Should gracefully handle loaders without _handle_reorg""" + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock(side_effect=NotImplementedError()) + mock_loader.state_store.invalidate_from_block = Mock(return_value=[]) + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader.state_store.invalidate_from_block.assert_called_once_with( + 'test_conn', 'test_table', 'ethereum', 1011 + ) + + def test_rewind_with_multiple_networks(self, mock_loader): + """Should process ethereum and polygon separately""" + watermark = ResumeWatermark( + ranges=[ + BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc'), + BlockRange(network='polygon', start=2000, end=2010, hash='0xdef'), + ] + ) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + assert mock_loader._handle_reorg.call_count == 2 + + first_call = mock_loader._handle_reorg.call_args_list[0] + assert first_call[0][0][0].network == 'ethereum' + assert first_call[0][0][0].start == 1011 + + second_call = mock_loader._handle_reorg.call_args_list[1] + assert second_call[0][0][0].network == 'polygon' + assert second_call[0][0][0].start == 2011 + + def test_rewind_uses_default_connection_name(self, mock_loader): + """Should use default connection name from loader class""" + 
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', connection_name=None) + + mock_loader.state_store.get_resume_position.assert_called_once_with('mockdata', 'test_table') diff --git a/tests/unit/test_streaming_types.py b/tests/unit/test_streaming_types.py index 47eede2..0eb09e9 100644 --- a/tests/unit/test_streaming_types.py +++ b/tests/unit/test_streaming_types.py @@ -336,11 +336,11 @@ class TestResumeWatermark: """Test ResumeWatermark serialization""" def test_to_json_full_data(self): - """Test serializing watermark with all fields""" + """Test serializing watermark to server format""" watermark = ResumeWatermark( ranges=[ - BlockRange(network='ethereum', start=100, end=200), - BlockRange(network='polygon', start=50, end=150), + BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'), + BlockRange(network='polygon', start=50, end=150, hash='0xdef456'), ], timestamp='2024-01-01T00:00:00Z', sequence=42, @@ -349,59 +349,41 @@ def test_to_json_full_data(self): json_str = watermark.to_json() data = json.loads(json_str) - assert len(data['ranges']) == 2 - assert data['ranges'][0]['network'] == 'ethereum' - assert data['timestamp'] == '2024-01-01T00:00:00Z' - assert data['sequence'] == 42 + assert len(data) == 2 + assert 'ethereum' in data + assert data['ethereum']['number'] == 200 + assert data['ethereum']['hash'] == '0xabc123' + assert 'polygon' in data + assert data['polygon']['number'] == 150 + assert data['polygon']['hash'] == '0xdef456' def test_to_json_minimal_data(self): """Test serializing watermark with only ranges""" - watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=100, end=200)]) + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=100, end=200, hash='0xabc123')]) json_str = watermark.to_json() data = json.loads(json_str) - assert len(data['ranges']) == 1 - assert 'timestamp' not in data - assert 'sequence' not in data - - def test_from_json_full_data(self): - """Test deserializing watermark with all fields""" - json_str = json.dumps( - { - 'ranges': [ - {'network': 'ethereum', 'start': 100, 'end': 200}, - {'network': 'polygon', 'start': 50, 'end': 150}, - ], - 'timestamp': '2024-01-01T00:00:00Z', - 'sequence': 42, - } - ) - - watermark = ResumeWatermark.from_json(json_str) - - assert len(watermark.ranges) == 2 - assert watermark.ranges[0].network == 'ethereum' - assert watermark.timestamp == '2024-01-01T00:00:00Z' - assert watermark.sequence == 42 + assert len(data) == 1 + assert 'ethereum' in data + assert data['ethereum']['number'] == 200 + assert data['ethereum']['hash'] == '0xabc123' - def test_round_trip_serialization(self): - """Test that serialization round-trip preserves data""" - original = ResumeWatermark( + def test_to_json_server_format(self): + watermark = ResumeWatermark( ranges=[ - BlockRange(network='ethereum', start=100, end=200), - BlockRange(network='polygon', start=50, end=150), + BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'), + BlockRange(network='polygon', start=50, end=150, hash='0xdef456'), ], - timestamp='2024-01-01T00:00:00Z', - sequence=42, ) - json_str = original.to_json() - restored = ResumeWatermark.from_json(json_str) + json_str = watermark.to_json() + data = json.loads(json_str) - assert len(restored.ranges) == len(original.ranges) - assert 
restored.timestamp == original.timestamp - assert restored.sequence == original.sequence + assert 'ethereum' in data + assert data['ethereum']['number'] == 200 + assert 'polygon' in data + assert data['polygon']['number'] == 150 @pytest.mark.unit
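
The reorg assertions in TestKafkaLoaderIntegration fix the event contract: reorg messages are keyed as reorg:<network> and carry _type, network, start_block, end_block and last_valid_hash, published either on the data topic or on a dedicated reorg_topic. The sketch below shows one way a downstream consumer might act on such an event by discarding buffered rows inside the invalidated span; the store layout and the apply_reorg helper are illustrative assumptions, not part of this patch.

def apply_reorg(store: dict, event: dict) -> int:
    """Drop every buffered row whose (network, block) key falls in the invalidated span."""
    if event.get('_type') != 'reorg':
        return 0
    network = event['network']
    start, end = event['start_block'], event['end_block']
    doomed = [key for key in store if key[0] == network and start <= key[1] <= end]
    for key in doomed:
        del store[key]
    return len(doomed)

# Example: the ethereum event from test_handle_reorg would clear blocks 100-200.
store = {('ethereum', 150): {'value': 1}, ('ethereum', 250): {'value': 2}}
assert apply_reorg(store, {'_type': 'reorg', 'network': 'ethereum',
                           'start_block': 100, 'end_block': 200,
                           'last_valid_hash': '0xabc123'}) == 1
assert ('ethereum', 250) in store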
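
The _rewind_to_watermark unit tests also encode an off-by-one that is easy to miss: the stored watermark marks the last fully committed block range per network, so recovery invalidates from end + 1 onward (1010 becomes 1011, 2010 becomes 2011). A minimal sketch of that mapping, assuming only the BlockRange fields the tests use:

from collections import namedtuple

# Illustrative stand-in for the BlockRange fields exercised by the tests.
Range = namedtuple('Range', 'network start end hash')

def first_block_to_replay(committed: Range) -> int:
    """A watermark range covers blocks [start, end] that were fully loaded;
    crash recovery invalidates and replays everything from end + 1."""
    return committed.end + 1

# Matches the expectations in test_rewind_calls_handle_reorg and
# test_rewind_with_multiple_networks.
assert first_block_to_replay(Range('ethereum', 1000, 1010, '0xabc')) == 1011
assert first_block_to_replay(Range('polygon', 2000, 2010, '0xdef')) == 2011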
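
Finally, the rewritten ResumeWatermark tests describe the new server-side watermark format only through assertions, and they drop the from_json round-trip cases, which suggests serialization is now treated as one-way, client to server. Spelled out, to_json() is expected to key the payload by network and report each range's end block as "number" alongside its hash. The function below is a hedged equivalent of that shape, not the library implementation.

import json

def to_server_format(ranges) -> str:
    """Illustrative equivalent of the payload the updated tests assert on."""
    # ranges: any iterable of objects exposing .network, .end and .hash,
    # such as the patch's BlockRange; each network keys its own entry.
    return json.dumps({r.network: {'number': r.end, 'hash': r.hash} for r in ranges})

# For the two ranges in test_to_json_full_data this yields:
# {"ethereum": {"number": 200, "hash": "0xabc123"},
#  "polygon": {"number": 150, "hash": "0xdef456"}}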