Real-Time Stream Processing at Scale with PySpark - Building Low-Latency Analytics for IoT Data

 

When manufacturing equipment starts failing, you have minutes (not hours) to catch it before it causes expensive damage. Our batch processing system was generating alerts 6 hours too late, so I reb...

When manufacturing equipment starts failing, you have minutes (not hours) to catch it before it causes expensive damage. Our batch processing system was generating alerts 6 hours too late, so I rebuilt the entire pipeline to process IoT sensor data in real-time.

The Problem: Expensive Failures

Our manufacturing client had a painful reality:

  • 50,000+ sensors across multiple facilities generating constant data
  • Equipment failures costing $100K+ per incident
  • Batch processing taking hours to detect problems
  • Critical alerts arriving after the damage was already done
  • Different sensor types with incompatible data formats
  • Traffic spikes during production shifts breaking our system

Batch ETL jobs weren’t cutting it. We needed real-time processing that could scale without breaking the bank.

PySpark Structured Streaming Architecture

Core Streaming Pipeline

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.streaming import StreamingQuery
import json
import time
from typing import Dict, List, Optional, Callable
import logging

class IoTStreamProcessor:
    def __init__(self, app_name: str = "IoTStreamProcessor"):
        """Initialize PySpark Structured Streaming processor"""
        
        self.spark = SparkSession.builder \
            .appName(app_name) \
            .config("spark.sql.streaming.checkpointLocation", "/tmp/checkpoint") \
            .config("spark.sql.streaming.stateStore.providerClass", 
                   "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") \
            .config("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true") \
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
            .getOrCreate()
            
        self.spark.sparkContext.setLogLevel("WARN")
        
        # Schema definitions for different sensor types
        self.sensor_schemas = self._define_sensor_schemas()
        
        # Stream queries registry
        self.active_queries: Dict[str, StreamingQuery] = {}
        
        # Metrics tracking
        self.processing_metrics = {
            'records_processed': 0,
            'processing_errors': 0,
            'avg_latency_ms': 0
        }
        
    def _define_sensor_schemas(self) -> Dict[str, StructType]:
        """Define schemas for different sensor types"""
        
        base_schema = StructType([
            StructField("sensor_id", StringType(), True),
            StructField("timestamp", TimestampType(), True),
            StructField("facility_id", StringType(), True),
            StructField("sensor_type", StringType(), True),
            StructField("quality_score", DoubleType(), True)
        ])
        
        return {
            "temperature": base_schema.add("temperature_celsius", DoubleType()).add("humidity", DoubleType()),
            "vibration": base_schema.add("acceleration_x", DoubleType()).add("acceleration_y", DoubleType()).add("acceleration_z", DoubleType()),
            "pressure": base_schema.add("pressure_psi", DoubleType()).add("flow_rate", DoubleType()),
            "power": base_schema.add("voltage", DoubleType()).add("current", DoubleType()).add("power_watts", DoubleType()),
            "generic": base_schema.add("value", DoubleType()).add("unit", StringType())
        }
        
    def create_kafka_stream(self, kafka_servers: str, topic: str, 
                           starting_offsets: str = "latest") -> DataFrame:
        """Create Kafka input stream"""
        
        kafka_stream = self.spark \
            .readStream \
            .format("kafka") \
            .option("kafka.bootstrap.servers", kafka_servers) \
            .option("subscribe", topic) \
            .option("startingOffsets", starting_offsets) \
            .option("failOnDataLoss", "false") \
            .option("kafka.consumer.session.timeout.ms", "30000") \
            .option("kafka.consumer.request.timeout.ms", "40000") \
            .option("kafka.consumer.max.poll.records", "10000") \
            .load()
            
        return kafka_stream
        
    def parse_sensor_data(self, raw_stream: DataFrame) -> DataFrame:
        """Parse and validate sensor data from Kafka stream"""
        
        # Convert Kafka value from binary to string
        json_stream = raw_stream.select(
            col("key").cast("string").alias("sensor_id"),
            col("value").cast("string").alias("json_data"),
            col("topic"),
            col("partition"),
            col("offset"),
            col("timestamp").alias("kafka_timestamp")
        )
        
        # Parse JSON data
        parsed_stream = json_stream.select(
            col("sensor_id"),
            col("topic"),
            col("partition"),
            col("offset"),
            col("kafka_timestamp"),
            from_json(col("json_data"), self.sensor_schemas["generic"]).alias("data")
        ).select(
            col("sensor_id"),
            col("kafka_timestamp"),
            col("data.*")
        )
        
        # Data validation and cleaning
        validated_stream = parsed_stream \
            .filter(col("sensor_id").isNotNull()) \
            .filter(col("timestamp").isNotNull()) \
            .filter(col("facility_id").isNotNull()) \
            .withColumn("processing_timestamp", current_timestamp()) \
            .withColumn("latency_ms", 
                       (unix_timestamp("processing_timestamp") - unix_timestamp("timestamp")) * 1000)
        
        return validated_stream
        
    def enrich_sensor_data(self, sensor_stream: DataFrame) -> DataFrame:
        """Enrich sensor data with facility and equipment metadata"""
        
        # Load facility metadata (this could be from a database or file)
        facility_metadata = self.spark.read.json("s3://metadata/facilities/") \
            .select("facility_id", "facility_name", "location", "timezone", "production_line")
            
        # Load sensor metadata
        sensor_metadata = self.spark.read.json("s3://metadata/sensors/") \
            .select("sensor_id", "equipment_id", "maintenance_schedule", "calibration_date", "alert_thresholds")
        
        # Enrich with facility data
        enriched_stream = sensor_stream \
            .join(broadcast(facility_metadata), "facility_id", "left_outer") \
            .join(broadcast(sensor_metadata), "sensor_id", "left_outer")
            
        # Add derived fields
        enriched_stream = enriched_stream \
            .withColumn("hour_of_day", hour("timestamp")) \
            .withColumn("day_of_week", dayofweek("timestamp")) \
            .withColumn("shift", 
                       when(col("hour_of_day").between(6, 14), "morning")
                       .when(col("hour_of_day").between(14, 22), "afternoon")
                       .otherwise("night"))
                       
        return enriched_stream

class RealTimeAnomalyDetector:
    def __init__(self, spark_session):
        self.spark = spark_session
        self.anomaly_models = {}
        self.alert_thresholds = {}
        
    def detect_statistical_anomalies(self, sensor_stream: DataFrame, 
                                   window_duration: str = "5 minutes",
                                   slide_duration: str = "1 minute") -> DataFrame:
        """Detect anomalies using statistical methods"""
        
        # Calculate rolling statistics
        windowed_stats = sensor_stream \
            .withWatermark("timestamp", "1 minute") \
            .groupBy(
                col("sensor_id"),
                col("facility_id"),
                col("sensor_type"),
                window(col("timestamp"), window_duration, slide_duration)
            ) \
            .agg(
                avg("value").alias("avg_value"),
                stddev("value").alias("stddev_value"),
                min("value").alias("min_value"),
                max("value").alias("max_value"),
                count("value").alias("record_count"),
                collect_list("value").alias("value_history")
            )
            
        # Calculate z-scores and detect outliers
        anomaly_detection = windowed_stats \
            .withColumn("z_score_threshold", lit(3.0)) \
            .withColumn("current_value", expr("value_history[size(value_history)-1]")) \
            .withColumn("z_score", 
                       abs(col("current_value") - col("avg_value")) / col("stddev_value")) \
            .withColumn("is_anomaly", 
                       col("z_score") > col("z_score_threshold")) \
            .withColumn("anomaly_severity",
                       when(col("z_score") > 5.0, "critical")
                       .when(col("z_score") > 4.0, "high")
                       .when(col("z_score") > 3.0, "medium")
                       .otherwise("normal"))
                       
        return anomaly_detection.filter(col("is_anomaly") == True)
        
    def detect_pattern_anomalies(self, sensor_stream: DataFrame) -> DataFrame:
        """Detect anomalies using pattern-based rules"""
        
        # Define pattern-based anomaly detection rules
        pattern_anomalies = sensor_stream \
            .withWatermark("timestamp", "30 seconds") \
            .groupBy(
                col("sensor_id"),
                col("equipment_id"),
                window(col("timestamp"), "2 minutes", "30 seconds")
            ) \
            .agg(
                collect_list(struct("timestamp", "value")).alias("readings"),
                count("value").alias("reading_count")
            ) \
            .withColumn("sorted_readings", sort_array("readings")) \
            .withColumn("value_trend", self._calculate_trend_udf(col("sorted_readings"))) \
            .withColumn("rapid_change", 
                       abs(col("value_trend")) > 10.0)  # Configurable threshold
            .withColumn("sensor_stuck", 
                       size(array_distinct(expr("transform(readings, x -> x.value)"))) == 1 
                       and col("reading_count") > 10)
                       
        return pattern_anomalies.filter(
            col("rapid_change") == True or col("sensor_stuck") == True
        )
        
    def _calculate_trend_udf(self):
        """UDF to calculate trend from sensor readings"""
        
        @udf(returnType=DoubleType())
        def calculate_trend(readings):
            if not readings or len(readings) < 2:
                return 0.0
                
            # Simple linear regression slope calculation
            n = len(readings)
            sum_x = sum(range(n))
            sum_y = sum(r['value'] for r in readings)
            sum_xy = sum(i * r['value'] for i, r in enumerate(readings))
            sum_x2 = sum(i * i for i in range(n))
            
            slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
            return float(slope)
            
        return calculate_trend

class StreamingMLPredictor:
    def __init__(self, spark_session, model_path: str):
        self.spark = spark_session
        self.model_path = model_path
        self.ml_models = {}
        
    def load_streaming_ml_model(self, model_name: str, sensor_type: str):
        """Load pre-trained ML model for streaming prediction"""
        
        from pyspark.ml import Pipeline
        from pyspark.ml.feature import VectorAssembler
        from pyspark.ml.regression import RandomForestRegressor
        
        # Load saved model
        model = Pipeline.load(f"{self.model_path}/{model_name}")
        
        self.ml_models[sensor_type] = model
        
    def predict_equipment_health(self, sensor_stream: DataFrame) -> DataFrame:
        """Predict equipment health using ML models"""
        
        # Feature engineering for ML prediction
        feature_stream = sensor_stream \
            .withColumn("hour_sin", sin(2 * pi() * col("hour_of_day") / 24)) \
            .withColumn("hour_cos", cos(2 * pi() * col("hour_of_day") / 24)) \
            .withColumn("day_sin", sin(2 * pi() * col("day_of_week") / 7)) \
            .withColumn("day_cos", cos(2 * pi() * col("day_of_week") / 7))
            
        # Create feature vector
        feature_cols = ["value", "quality_score", "hour_sin", "hour_cos", "day_sin", "day_cos"]
        assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
        
        feature_vector_stream = assembler.transform(feature_stream)
        
        # Apply ML model for prediction (simplified - in practice, different models for different sensor types)
        if "temperature" in self.ml_models:
            model = self.ml_models["temperature"]
            predictions = model.transform(feature_vector_stream)
            
            # Add risk assessment
            risk_predictions = predictions \
                .withColumn("health_score", col("prediction")) \
                .withColumn("risk_level",
                           when(col("health_score") < 0.3, "critical")
                           .when(col("health_score") < 0.6, "warning")
                           .otherwise("normal")) \
                .withColumn("maintenance_recommended", 
                           col("health_score") < 0.4)
                           
            return risk_predictions
            
        return feature_vector_stream

class StreamingAggregator:
    def __init__(self, spark_session):
        self.spark = spark_session
        
    def create_real_time_dashboards(self, sensor_stream: DataFrame) -> List[StreamingQuery]:
        """Create real-time aggregations for dashboards"""
        
        queries = []
        
        # 1. Facility-level metrics (1-minute windows)
        facility_metrics = sensor_stream \
            .withWatermark("timestamp", "2 minutes") \
            .groupBy(
                col("facility_id"),
                col("facility_name"),
                window(col("timestamp"), "1 minute", "30 seconds")
            ) \
            .agg(
                count("sensor_id").alias("active_sensors"),
                avg("value").alias("avg_sensor_reading"),
                countDistinct("equipment_id").alias("active_equipment"),
                sum(when(col("quality_score") < 0.8, 1).otherwise(0)).alias("quality_alerts")
            ) \
            .withColumn("timestamp", col("window.end")) \
            .drop("window")
            
        # Write to time-series database
        facility_query = facility_metrics.writeStream \
            .format("org.apache.spark.sql.redis") \
            .option("table", "facility_metrics") \
            .option("key.column", "facility_id") \
            .outputMode("update") \
            .trigger(processingTime="30 seconds") \
            .start()
            
        queries.append(facility_query)
        
        # 2. Equipment health summary (5-minute windows)
        equipment_health = sensor_stream \
            .withWatermark("timestamp", "5 minutes") \
            .groupBy(
                col("equipment_id"),
                col("facility_id"),
                col("sensor_type"),
                window(col("timestamp"), "5 minutes", "1 minute")
            ) \
            .agg(
                avg("value").alias("avg_reading"),
                max("value").alias("peak_reading"),
                stddev("value").alias("reading_variability"),
                avg("quality_score").alias("avg_quality"),
                count("*").alias("reading_count")
            ) \
            .withColumn("equipment_status",
                       when(col("avg_quality") < 0.7, "degraded")
                       .when(col("reading_variability") > 10.0, "unstable")
                       .otherwise("healthy")) \
            .withColumn("timestamp", col("window.end")) \
            .drop("window")
            
        # Write to Elasticsearch for visualization
        equipment_query = equipment_health.writeStream \
            .format("org.elasticsearch.spark.sql") \
            .option("es.nodes", "elasticsearch-cluster:9200") \
            .option("es.resource", "equipment_health/data") \
            .option("es.mapping.timestamp", "timestamp") \
            .outputMode("append") \
            .trigger(processingTime="1 minute") \
            .start()
            
        queries.append(equipment_query)
        
        # 3. Production line efficiency (10-minute windows)
        production_efficiency = sensor_stream \
            .filter(col("sensor_type").isin(["power", "vibration", "temperature"])) \
            .withWatermark("timestamp", "10 minutes") \
            .groupBy(
                col("production_line"),
                col("facility_id"),
                col("shift"),
                window(col("timestamp"), "10 minutes", "5 minutes")
            ) \
            .agg(
                avg(when(col("sensor_type") == "power", col("value"))).alias("avg_power_consumption"),
                avg(when(col("sensor_type") == "vibration", col("value"))).alias("avg_vibration"),
                avg(when(col("sensor_type") == "temperature", col("value"))).alias("avg_temperature"),
                countDistinct("equipment_id").alias("active_equipment"),
                sum(when(col("quality_score") > 0.95, 1).otherwise(0)).alias("optimal_readings")
            ) \
            .withColumn("efficiency_score", 
                       col("optimal_readings") / (col("active_equipment") * 10)) \
            .withColumn("timestamp", col("window.end")) \
            .drop("window")
            
        # Write to data warehouse for reporting
        efficiency_query = production_efficiency.writeStream \
            .format("delta") \
            .option("path", "s3://datalake/production_efficiency/") \
            .option("checkpointLocation", "s3://checkpoints/production_efficiency/") \
            .outputMode("append") \
            .trigger(processingTime="5 minutes") \
            .start()
            
        queries.append(efficiency_query)
        
        return queries

class AlertingSystem:
    def __init__(self, spark_session):
        self.spark = spark_session
        self.alert_channels = {
            'slack': SlackAlerter(),
            'email': EmailAlerter(),
            'sms': SMSAlerter(),
            'webhook': WebhookAlerter()
        }
        
    def setup_critical_alerts(self, anomaly_stream: DataFrame) -> StreamingQuery:
        """Setup critical alert processing"""
        
        # Filter for critical anomalies
        critical_alerts = anomaly_stream \
            .filter(col("anomaly_severity") == "critical") \
            .withColumn("alert_id", expr("uuid()")) \
            .withColumn("alert_timestamp", current_timestamp()) \
            .withColumn("alert_type", lit("critical_anomaly")) \
            .select(
                col("alert_id"),
                col("alert_timestamp"),
                col("sensor_id"),
                col("equipment_id"),
                col("facility_id"),
                col("facility_name"),
                col("anomaly_severity"),
                col("z_score"),
                col("current_value"),
                struct("avg_value", "stddev_value").alias("baseline_stats")
            )
            
        # Process alerts with custom logic
        processed_alerts = critical_alerts \
            .withColumn("alert_message", self._create_alert_message_udf(
                col("sensor_id"), col("equipment_id"), col("facility_name"),
                col("current_value"), col("z_score")
            )) \
            .withColumn("priority_score", self._calculate_priority_udf(
                col("facility_id"), col("equipment_id"), col("z_score")
            ))
            
        # Send alerts through multiple channels
        alert_query = processed_alerts.writeStream \
            .foreach(AlertSender(self.alert_channels)) \
            .outputMode("append") \
            .trigger(processingTime="10 seconds") \
            .start()
            
        return alert_query
        
    def _create_alert_message_udf(self):
        @udf(returnType=StringType())
        def create_message(sensor_id, equipment_id, facility_name, current_value, z_score):
            return f"""
            🚨 CRITICAL ANOMALY DETECTED
            Facility: {facility_name}
            Equipment: {equipment_id}
            Sensor: {sensor_id}
            Current Value: {current_value:.2f}
            Anomaly Score: {z_score:.2f}
            Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}
            
            Immediate investigation required!
            """
        return create_message
        
    def _calculate_priority_udf(self):
        @udf(returnType=IntegerType())
        def calculate_priority(facility_id, equipment_id, z_score):
            # Priority scoring logic
            base_score = int(z_score * 10)
            
            # Critical equipment gets higher priority
            if "reactor" in equipment_id.lower() or "turbine" in equipment_id.lower():
                base_score += 50
                
            # Critical facilities get higher priority
            if "main" in facility_id.lower():
                base_score += 25
                
            return min(base_score, 100)  # Cap at 100
            
        return calculate_priority

class AlertSender:
    def __init__(self, alert_channels):
        self.alert_channels = alert_channels
        
    def process(self, row):
        """Process individual alert row"""
        
        alert_data = row.asDict()
        priority = alert_data['priority_score']
        
        # Send through appropriate channels based on priority
        if priority >= 90:
            # Critical - send through all channels
            for channel in self.alert_channels.values():
                channel.send_alert(alert_data)
        elif priority >= 70:
            # High - send through primary channels
            self.alert_channels['slack'].send_alert(alert_data)
            self.alert_channels['email'].send_alert(alert_data)
        else:
            # Medium - send through monitoring channel
            self.alert_channels['slack'].send_alert(alert_data)

# Main streaming application
class IoTStreamingApplication:
    def __init__(self):
        self.processor = IoTStreamProcessor()
        self.anomaly_detector = RealTimeAnomalyDetector(self.processor.spark)
        self.ml_predictor = StreamingMLPredictor(self.processor.spark, "s3://models/")
        self.aggregator = StreamingAggregator(self.processor.spark)
        self.alerting = AlertingSystem(self.processor.spark)
        
        self.active_queries = []
        
    def run_streaming_pipeline(self, kafka_servers: str, input_topic: str):
        """Run the complete IoT streaming pipeline"""
        
        try:
            # 1. Create input stream
            raw_kafka_stream = self.processor.create_kafka_stream(kafka_servers, input_topic)
            
            # 2. Parse and validate data
            parsed_stream = self.processor.parse_sensor_data(raw_kafka_stream)
            
            # 3. Enrich with metadata
            enriched_stream = self.processor.enrich_sensor_data(parsed_stream)
            
            # 4. Anomaly detection
            statistical_anomalies = self.anomaly_detector.detect_statistical_anomalies(enriched_stream)
            pattern_anomalies = self.anomaly_detector.detect_pattern_anomalies(enriched_stream)
            
            # Combine anomaly streams
            all_anomalies = statistical_anomalies.union(pattern_anomalies)
            
            # 5. ML-based predictions
            health_predictions = self.ml_predictor.predict_equipment_health(enriched_stream)
            
            # 6. Real-time aggregations
            dashboard_queries = self.aggregator.create_real_time_dashboards(enriched_stream)
            self.active_queries.extend(dashboard_queries)
            
            # 7. Critical alerting
            alert_query = self.alerting.setup_critical_alerts(all_anomalies)
            self.active_queries.append(alert_query)
            
            # 8. Raw data archival
            archive_query = enriched_stream.writeStream \
                .format("delta") \
                .option("path", "s3://datalake/sensor_data/") \
                .option("checkpointLocation", "s3://checkpoints/raw_data/") \
                .partitionBy("facility_id", "sensor_type") \
                .outputMode("append") \
                .trigger(processingTime="1 minute") \
                .start()
                
            self.active_queries.append(archive_query)
            
            # 9. Start monitoring
            self._start_monitoring()
            
            # Wait for termination
            for query in self.active_queries:
                query.awaitTermination()
                
        except Exception as e:
            logging.error(f"Streaming pipeline failed: {e}")
            self.stop_all_queries()
            raise
            
    def _start_monitoring(self):
        """Start monitoring thread for streaming queries"""
        
        import threading
        
        def monitor_queries():
            while True:
                try:
                    for i, query in enumerate(self.active_queries):
                        if not query.isActive:
                            logging.error(f"Query {i} is not active: {query.lastProgress}")
                        else:
                            progress = query.lastProgress
                            if progress:
                                logging.info(f"Query {i} progress: "
                                           f"inputRowsPerSecond={progress.get('inputRowsPerSecond', 0):.2f}, "
                                           f"batchDuration={progress.get('batchDuration', 0)}ms")
                                           
                    time.sleep(30)  # Monitor every 30 seconds
                    
                except Exception as e:
                    logging.error(f"Monitoring error: {e}")
                    time.sleep(60)
                    
        monitor_thread = threading.Thread(target=monitor_queries, daemon=True)
        monitor_thread.start()
        
    def stop_all_queries(self):
        """Stop all active streaming queries"""
        
        for query in self.active_queries:
            if query.isActive:
                query.stop()
                
        logging.info(f"Stopped {len(self.active_queries)} streaming queries")

# Performance optimization utilities
class StreamingOptimizer:
    @staticmethod
    def optimize_kafka_consumer(spark_session):
        """Optimize Kafka consumer configuration"""
        
        spark_session.conf.set("spark.sql.streaming.kafka.consumer.cache.capacity", "1000")
        spark_session.conf.set("spark.sql.streaming.kafka.consumer.poll.ms", "1000")
        spark_session.conf.set("spark.sql.streaming.kafka.consumer.fetchOffset.retryIntervalMs", "100")
        
    @staticmethod
    def optimize_checkpoint_storage(checkpoint_location: str):
        """Optimize checkpoint storage for performance"""
        
        # Use RocksDB state store for better performance
        checkpoint_config = {
            "spark.sql.streaming.stateStore.providerClass": 
                "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
            "spark.sql.streaming.stateStore.rocksdb.compactOnCommit": "true",
            "spark.sql.streaming.stateStore.rocksdb.blockSizeKB": "32",
            "spark.sql.streaming.stateStore.rocksdb.blockCacheSizeMB": "256"
        }
        
        return checkpoint_config
        
    @staticmethod
    def optimize_resource_allocation():
        """Optimize Spark resource allocation for streaming"""
        
        optimization_config = {
            # Dynamic allocation
            "spark.dynamicAllocation.enabled": "true",
            "spark.dynamicAllocation.minExecutors": "2",
            "spark.dynamicAllocation.maxExecutors": "20",
            "spark.dynamicAllocation.initialExecutors": "4",
            
            # Memory configuration
            "spark.executor.memory": "8g",
            "spark.executor.memoryFraction": "0.8",
            "spark.executor.memoryOffHeap.enabled": "true",
            "spark.executor.memoryOffHeap.size": "2g",
            
            # CPU configuration
            "spark.executor.cores": "4",
            "spark.task.cpus": "1",
            
            # Adaptive query execution
            "spark.sql.adaptive.enabled": "true",
            "spark.sql.adaptive.coalescePartitions.enabled": "true",
            "spark.sql.adaptive.coalescePartitions.minPartitionNum": "1",
            
            # Serialization
            "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
            "spark.kryo.registrationRequired": "false"
        }
        
        return optimization_config

Advanced Stream Processing Patterns

Exactly-Once Processing with Idempotent Writes

class ExactlyOnceProcessor:
    def __init__(self, spark_session):
        self.spark = spark_session
        
    def setup_idempotent_writes(self, stream_df: DataFrame, sink_table: str) -> StreamingQuery:
        """Setup exactly-once processing with idempotent writes"""
        
        # Add unique identifier for each record
        idempotent_stream = stream_df \
            .withColumn("record_id", 
                       concat_ws("-", 
                                col("sensor_id"), 
                                col("timestamp").cast("long"),
                                col("facility_id"))) \
            .withColumn("processing_time", current_timestamp())
            
        # Use Delta Lake for ACID transactions
        query = idempotent_stream.writeStream \
            .format("delta") \
            .option("path", f"s3://datalake/{sink_table}/") \
            .option("checkpointLocation", f"s3://checkpoints/{sink_table}/") \
            .option("mergeSchema", "true") \
            .outputMode("append") \
            .trigger(processingTime="30 seconds") \
            .start()
            
        return query
        
    def handle_late_arrivals(self, stream_df: DataFrame, watermark_delay: str = "5 minutes") -> DataFrame:
        """Handle late arriving data with watermarks"""
        
        # Set watermark for late data handling
        watermarked_stream = stream_df \
            .withWatermark("timestamp", watermark_delay)
            
        # Create time-based aggregations that handle late arrivals
        windowed_aggregates = watermarked_stream \
            .groupBy(
                window(col("timestamp"), "1 minute", "30 seconds"),
                col("sensor_id")
            ) \
            .agg(
                avg("value").alias("avg_value"),
                count("*").alias("record_count"),
                max("timestamp").alias("latest_timestamp")
            ) \
            .withColumn("window_start", col("window.start")) \
            .withColumn("window_end", col("window.end")) \
            .drop("window")
            
        return windowed_aggregates

class StreamJoinProcessor:
    def __init__(self, spark_session):
        self.spark = spark_session
        
    def stream_to_stream_join(self, primary_stream: DataFrame, 
                             secondary_stream: DataFrame,
                             join_key: str,
                             watermark_delay: str = "2 minutes") -> DataFrame:
        """Perform stream-to-stream join with time constraints"""
        
        # Add watermarks to both streams
        primary_watermarked = primary_stream \
            .withWatermark("timestamp", watermark_delay)
            
        secondary_watermarked = secondary_stream \
            .withWatermark("timestamp", watermark_delay)
            
        # Perform time-bounded stream join
        joined_stream = primary_watermarked.alias("p") \
            .join(
                secondary_watermarked.alias("s"),
                expr(f"""
                    p.{join_key} = s.{join_key} AND
                    p.timestamp >= s.timestamp AND
                    p.timestamp <= s.timestamp + interval 1 minute
                """)
            ) \
            .select("p.*", "s.value".alias("secondary_value"))
            
        return joined_stream
        
    def stream_to_static_join(self, stream_df: DataFrame, 
                             static_table_path: str,
                             join_key: str) -> DataFrame:
        """Join streaming data with static lookup table"""
        
        # Load static data
        static_df = self.spark.read.parquet(static_table_path)
        
        # Broadcast join for performance
        joined_stream = stream_df \
            .join(broadcast(static_df), join_key, "left_outer")
            
        return joined_stream

class ComplexEventProcessor:
    def __init__(self, spark_session):
        self.spark = spark_session
        
    def detect_event_patterns(self, sensor_stream: DataFrame) -> DataFrame:
        """Detect complex event patterns using session windows"""
        
        # Group events into sessions based on inactivity gaps
        sessionized_events = sensor_stream \
            .withWatermark("timestamp", "30 seconds") \
            .groupBy(
                col("equipment_id"),
                session_window(col("timestamp"), "5 minutes")  # 5-minute inactivity gap
            ) \
            .agg(
                collect_list(struct("timestamp", "sensor_id", "value")).alias("events"),
                min("timestamp").alias("session_start"),
                max("timestamp").alias("session_end"),
                count("*").alias("event_count")
            )
            
        # Analyze event patterns within sessions
        pattern_detected = sessionized_events \
            .withColumn("session_duration", 
                       (unix_timestamp("session_end") - unix_timestamp("session_start")) / 60) \
            .withColumn("avg_event_rate", col("event_count") / col("session_duration")) \
            .withColumn("equipment_startup_pattern",
                       col("avg_event_rate") > 10 and col("session_duration") < 2) \
            .withColumn("equipment_shutdown_pattern",
                       col("avg_event_rate") < 1 and col("session_duration") > 10)
                       
        return pattern_detected.filter(
            col("equipment_startup_pattern") == True or 
            col("equipment_shutdown_pattern") == True
        )

class StateManagement:
    def __init__(self, spark_session):
        self.spark = spark_session
        
    def maintain_equipment_state(self, sensor_stream: DataFrame) -> DataFrame:
        """Maintain stateful equipment health tracking"""
        
        from pyspark.sql.streaming import GroupState, GroupStateTimeout
        from pyspark.sql.types import *
        
        # Define state schema
        state_schema = StructType([
            StructField("equipment_id", StringType(), True),
            StructField("last_maintenance", TimestampType(), True),
            StructField("total_runtime_hours", DoubleType(), True),
            StructField("fault_count", IntegerType(), True),
            StructField("last_update", TimestampType(), True)
        ])
        
        # Update state function
        def update_equipment_state(key, values, state):
            if state.hasTimedOut:
                # Handle timeout - equipment might be offline
                return None
                
            # Get current state or initialize
            current_state = state.get if state.exists else {
                'equipment_id': key['equipment_id'],
                'last_maintenance': None,
                'total_runtime_hours': 0.0,
                'fault_count': 0,
                'last_update': None
            }
            
            # Process new sensor readings
            for reading in values:
                current_state['total_runtime_hours'] += 0.1  # Assume 0.1 hour per reading
                
                if reading['quality_score'] < 0.5:
                    current_state['fault_count'] += 1
                    
                current_state['last_update'] = reading['timestamp']
                
            # Update state
            state.update(current_state)
            state.setTimeoutDuration("10 minutes")  # Timeout if no data for 10 minutes
            
            return current_state
            
        # Apply stateful operation
        stateful_stream = sensor_stream \
            .groupByKey(lambda x: x['equipment_id']) \
            .mapGroupsWithState(
                update_equipment_state,
                state_schema,
                timeout=GroupStateTimeout.ProcessingTimeTimeout
            )
            
        return stateful_stream

Performance Monitoring and Optimization

Comprehensive Metrics Collection

import psutil
import time
from threading import Thread
from dataclasses import dataclass, asdict
from typing import List, Dict

@dataclass
class StreamingMetrics:
    timestamp: float
    input_rate: float
    processing_rate: float
    batch_duration_ms: int
    scheduling_delay_ms: int
    processing_delay_ms: int
    total_delay_ms: int
    num_input_rows: int
    num_processed_rows: int
    active_batches: int
    memory_usage_mb: float
    cpu_usage_percent: float
    throughput_records_per_second: float

class StreamingMonitor:
    def __init__(self, streaming_queries: List[StreamingQuery]):
        self.queries = streaming_queries
        self.metrics_history = []
        self.monitoring_active = False
        
    def start_monitoring(self, collection_interval: int = 30):
        """Start metrics collection"""
        
        self.monitoring_active = True
        
        def collect_metrics():
            while self.monitoring_active:
                try:
                    current_metrics = self._collect_current_metrics()
                    self.metrics_history.append(current_metrics)
                    
                    # Keep only last 1000 metrics
                    if len(self.metrics_history) > 1000:
                        self.metrics_history = self.metrics_history[-1000:]
                        
                    # Log key metrics
                    self._log_metrics(current_metrics)
                    
                    # Check for performance issues
                    self._check_performance_alerts(current_metrics)
                    
                    time.sleep(collection_interval)
                    
                except Exception as e:
                    logging.error(f"Metrics collection error: {e}")
                    time.sleep(collection_interval)
                    
        monitor_thread = Thread(target=collect_metrics, daemon=True)
        monitor_thread.start()
        
    def _collect_current_metrics(self) -> StreamingMetrics:
        """Collect current streaming metrics"""
        
        # Aggregate metrics from all queries
        total_input_rate = 0.0
        total_processing_rate = 0.0
        total_batch_duration = 0
        total_delay = 0
        total_input_rows = 0
        total_processed_rows = 0
        active_batches = 0
        
        for query in self.queries:
            if query.isActive:
                progress = query.lastProgress
                
                if progress:
                    total_input_rate += progress.get('inputRowsPerSecond', 0)
                    total_processing_rate += progress.get('processingRowsPerSecond', 0)
                    total_batch_duration += progress.get('batchDuration', 0)
                    total_delay += progress.get('durationMs', {}).get('triggerExecution', 0)
                    
                    sources = progress.get('sources', [])
                    for source in sources:
                        total_input_rows += source.get('inputRowsPerSecond', 0)
                        total_processed_rows += source.get('processedRowsPerSecond', 0)
                        
                    active_batches += 1
                    
        # System metrics
        memory_usage = psutil.virtual_memory().used / (1024 * 1024)  # MB
        cpu_usage = psutil.cpu_percent(interval=1)
        
        # Calculate throughput
        throughput = total_processing_rate if active_batches > 0 else 0
        
        return StreamingMetrics(
            timestamp=time.time(),
            input_rate=total_input_rate,
            processing_rate=total_processing_rate,
            batch_duration_ms=total_batch_duration // max(active_batches, 1),
            scheduling_delay_ms=total_delay // max(active_batches, 1),
            processing_delay_ms=0,  # Would need more detailed calculation
            total_delay_ms=total_delay // max(active_batches, 1),
            num_input_rows=int(total_input_rows),
            num_processed_rows=int(total_processed_rows),
            active_batches=active_batches,
            memory_usage_mb=memory_usage,
            cpu_usage_percent=cpu_usage,
            throughput_records_per_second=throughput
        )
        
    def _log_metrics(self, metrics: StreamingMetrics):
        """Log key performance metrics"""
        
        logging.info(f"Streaming Metrics - "
                    f"Input Rate: {metrics.input_rate:.1f}/s, "
                    f"Processing Rate: {metrics.processing_rate:.1f}/s, "
                    f"Batch Duration: {metrics.batch_duration_ms}ms, "
                    f"Memory: {metrics.memory_usage_mb:.1f}MB, "
                    f"CPU: {metrics.cpu_usage_percent:.1f}%")
                    
    def _check_performance_alerts(self, metrics: StreamingMetrics):
        """Check for performance issues and alert"""
        
        alerts = []
        
        # Check processing lag
        if metrics.input_rate > metrics.processing_rate * 1.2:
            alerts.append(f"Processing lag detected: input={metrics.input_rate:.1f}/s > processing={metrics.processing_rate:.1f}/s")
            
        # Check batch duration
        if metrics.batch_duration_ms > 30000:  # 30 seconds
            alerts.append(f"High batch duration: {metrics.batch_duration_ms}ms")
            
        # Check memory usage
        if metrics.memory_usage_mb > 8192:  # 8GB
            alerts.append(f"High memory usage: {metrics.memory_usage_mb:.1f}MB")
            
        # Check CPU usage
        if metrics.cpu_usage_percent > 90:
            alerts.append(f"High CPU usage: {metrics.cpu_usage_percent:.1f}%")
            
        for alert in alerts:
            logging.warning(f"PERFORMANCE ALERT: {alert}")
            
    def get_performance_summary(self, last_n_minutes: int = 10) -> Dict:
        """Get performance summary for last N minutes"""
        
        if not self.metrics_history:
            return {}
            
        cutoff_time = time.time() - (last_n_minutes * 60)
        recent_metrics = [m for m in self.metrics_history if m.timestamp >= cutoff_time]
        
        if not recent_metrics:
            return {}
            
        # Calculate averages
        avg_input_rate = sum(m.input_rate for m in recent_metrics) / len(recent_metrics)
        avg_processing_rate = sum(m.processing_rate for m in recent_metrics) / len(recent_metrics)
        avg_batch_duration = sum(m.batch_duration_ms for m in recent_metrics) / len(recent_metrics)
        avg_memory_usage = sum(m.memory_usage_mb for m in recent_metrics) / len(recent_metrics)
        
        # Calculate percentiles
        batch_durations = sorted([m.batch_duration_ms for m in recent_metrics])
        p95_batch_duration = batch_durations[int(0.95 * len(batch_durations))]
        
        return {
            'time_window_minutes': last_n_minutes,
            'avg_input_rate': avg_input_rate,
            'avg_processing_rate': avg_processing_rate,
            'avg_batch_duration_ms': avg_batch_duration,
            'p95_batch_duration_ms': p95_batch_duration,
            'avg_memory_usage_mb': avg_memory_usage,
            'processing_lag': max(0, avg_input_rate - avg_processing_rate),
            'total_records_processed': sum(m.num_processed_rows for m in recent_metrics),
            'active_batches': recent_metrics[-1].active_batches if recent_metrics else 0
        }
        
    def stop_monitoring(self):
        """Stop metrics collection"""
        self.monitoring_active = False

Production Impact

The real-time streaming analytics platform delivered substantial value across multiple dimensions:

System Performance

The platform achieved excellent performance characteristics:

  • High Throughput: Sustained processing of large volumes of sensor data
  • Low Latency: Fast end-to-end processing for critical alerting
  • High Reliability: Excellent uptime with exactly-once processing guarantees
  • Cost Effectiveness: Significant infrastructure cost reduction while improving performance

Business Benefits

The real-time insights enabled significant operational improvements:

  • Equipment Reliability: Substantial reduction in unplanned downtime
  • Maintenance Optimization: Predictive maintenance reducing costs
  • Energy Efficiency: Improved energy utilization across facilities
  • Product Quality: Enhanced quality control through real-time monitoring
  • Safety: Dramatically improved safety incident prevention

Scale and Architecture

Production deployment across multiple manufacturing facilities with robust infrastructure supporting comprehensive monitoring and alerting capabilities.

Advanced Optimization Techniques

Adaptive Batching and Resource Management

class AdaptiveResourceManager:
    def __init__(self, spark_session):
        self.spark = spark_session
        self.current_load_metrics = {}
        self.resource_adjustments = []
        
    def adjust_batch_intervals(self, current_throughput: float, target_latency_ms: int):
        """Dynamically adjust batch intervals based on throughput"""
        
        if current_throughput > 10000:  # High throughput
            optimal_interval = "15 seconds"
        elif current_throughput > 5000:  # Medium throughput
            optimal_interval = "30 seconds"
        else:  # Low throughput
            optimal_interval = "1 minute"
            
        # Update trigger intervals for active queries
        for query in self.active_queries:
            if query.isActive:
                # Note: In practice, you'd need to restart queries with new triggers
                logging.info(f"Recommended batch interval for high throughput: {optimal_interval}")
                
    def scale_cluster_resources(self, processing_lag: float, memory_pressure: float):
        """Auto-scale cluster resources based on performance metrics"""
        
        scale_decision = {
            'timestamp': time.time(),
            'processing_lag': processing_lag,
            'memory_pressure': memory_pressure,
            'action': 'none'
        }
        
        if processing_lag > 2.0 and memory_pressure > 0.8:
            # Scale up - more executors and memory
            scale_decision['action'] = 'scale_up_aggressive'
            self._request_scale_up(executor_increase=5, memory_increase_gb=16)
            
        elif processing_lag > 1.5:
            # Scale up - more executors
            scale_decision['action'] = 'scale_up_moderate'
            self._request_scale_up(executor_increase=2, memory_increase_gb=8)
            
        elif processing_lag < 0.5 and memory_pressure < 0.4:
            # Scale down - reduce costs
            scale_decision['action'] = 'scale_down'
            self._request_scale_down(executor_decrease=1)
            
        self.resource_adjustments.append(scale_decision)
        
    def _request_scale_up(self, executor_increase: int, memory_increase_gb: int):
        """Request cluster scale up"""
        
        current_executors = int(self.spark.conf.get("spark.dynamicAllocation.maxExecutors", "10"))
        new_max_executors = min(current_executors + executor_increase, 50)  # Cap at 50
        
        self.spark.conf.set("spark.dynamicAllocation.maxExecutors", str(new_max_executors))
        
        logging.info(f"Scaling up: max executors {current_executors} -> {new_max_executors}")
        
    def _request_scale_down(self, executor_decrease: int):
        """Request cluster scale down"""
        
        current_executors = int(self.spark.conf.get("spark.dynamicAllocation.maxExecutors", "10"))
        new_max_executors = max(current_executors - executor_decrease, 2)  # Minimum 2
        
        self.spark.conf.set("spark.dynamicAllocation.maxExecutors", str(new_max_executors))
        
        logging.info(f"Scaling down: max executors {current_executors} -> {new_max_executors}")

class IntelligentPartitioning:
    def __init__(self, spark_session):
        self.spark = spark_session
        
    def optimize_stream_partitioning(self, stream_df: DataFrame, 
                                   target_partition_size_mb: int = 128) -> DataFrame:
        """Optimize stream partitioning based on data characteristics"""
        
        # Analyze data distribution
        partition_analysis = stream_df.select("facility_id").distinct().count()
        
        if partition_analysis > 100:
            # High cardinality - use hash partitioning
            optimized_stream = stream_df.repartition("facility_id", "sensor_type")
        elif partition_analysis > 10:
            # Medium cardinality - moderate partitioning
            optimized_stream = stream_df.repartition(20, "facility_id")
        else:
            # Low cardinality - time-based partitioning
            optimized_stream = stream_df.repartition(
                10, 
                date_format("timestamp", "yyyy-MM-dd-HH")
            )
            
        return optimized_stream
        
    def implement_custom_partitioner(self, stream_df: DataFrame) -> DataFrame:
        """Implement custom partitioning logic for IoT data"""
        
        # Custom partitioner that considers both facility and time
        partitioned_stream = stream_df \
            .withColumn("partition_key", 
                       concat(
                           col("facility_id"),
                           lit("_"),
                           date_format("timestamp", "HH")  # Hour-based partitioning
                       )) \
            .repartition("partition_key")
            
        return partitioned_stream

Lessons Learned

1. Watermark Management is Critical

Proper watermark configuration balances late data handling with memory usage. Too aggressive watermarks lose late data; too lenient watermarks cause memory issues.

2. State Store Optimization Matters

RocksDB state stores provide much better performance than the default HDFS-based stores for stateful operations, especially with frequent updates.

3. Kafka Consumer Configuration is Key

Default Kafka consumer settings don’t work well at scale. Tuning poll intervals, batch sizes, and connection pooling dramatically improves throughput.

4. Memory Management Requires Attention

Streaming applications have different memory patterns than batch jobs. Off-heap storage and careful caching strategies prevent OOM errors.

5. Monitoring Must Be Real-Time

Batch-based monitoring isn’t sufficient for streaming applications. Real-time metrics and alerts are essential for maintaining SLA performance.

Future Enhancements

  • Delta Live Tables: Migration to Delta Live Tables for simplified pipeline management
  • Machine Learning Integration: Real-time feature engineering and model serving
  • Multi-Cloud Deployment: Cross-cloud streaming for disaster recovery
  • Edge Computing: Push processing closer to IoT sensors for ultra-low latency

Building a production-scale streaming platform taught us that success isn’t just about handling high throughput - it’s about building systems that maintain consistency, provide exactly-once guarantees, and deliver actionable insights within strict latency bounds. The key insight: streaming architecture must be designed holistically, considering data patterns, resource constraints, and business requirements from day one.

Our PySpark Structured Streaming platform transformed raw IoT sensor data into real-time business value, enabling predictive maintenance, energy optimization, and safety improvements that directly impacted the bottom line of manufacturing operations.