# Pickle Deserialization via `pickle.loads` on Untrusted Data

Language: Python
Severity: Critical
CWE: CWE-502

## Source
9

## Flow
9-12-13

## Sink
13

## Vulnerable Code
```python
import pickle
import base64
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/ai/model/restore', methods=['POST'])
def restore_ml_checkpoint():
    checkpoint_b64 = request.json.get('checkpoint_data')
    if not checkpoint_b64:
        return jsonify({'error': 'No checkpoint provided'}), 400
    serialized_state = base64.b64decode(checkpoint_b64)
    model_state = pickle.loads(serialized_state)
    return jsonify({'status': 'Model restored', 'layers': len(model_state.get('layers', []))})

if __name__ == '__main__':
    app.run(debug=True)
```

## Explanation

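The endpoint reads `checkpoint_data` from the request body, base64-decodes it, and passes the result directly to `pickle.loads()` with no authentication or validation. Pickle is not a passive data format: its opcodes can import and call arbitrary callables during loading (for example through an object's `__reduce__` method), so an attacker who controls the payload can execute arbitrary code in the Flask process. The sample also starts the server with `debug=True`, which enables the Werkzeug interactive debugger and should never be exposed in production.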
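
To make the mechanism concrete, here is a minimal, harmless sketch of how a pickle payload runs code at load time. The class name `Payload` is hypothetical, and `eval("6 * 7")` stands in for whatever an attacker would actually invoke (for instance a shell command).

```python
import pickle


class Payload:
    """Hypothetical stand-in for an attacker-crafted object."""

    def __reduce__(self):
        # Tells pickle: "to rebuild this object, call eval('6 * 7')".
        # The callable and its arguments are fully attacker-controlled.
        return (eval, ("6 * 7",))


payload = pickle.dumps(Payload())

# The victim only calls pickle.loads(); the callable runs during loading,
# before any application-level validation could inspect the result.
result = pickle.loads(payload)
print(result)  # 42, the return value of eval, not a Payload instance
```

The receiving side does not need the `Payload` class at all, because the stream only references `builtins.eval`. That is why type checks applied after `pickle.loads()` returns cannot mitigate the issue.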

## Remediation

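The fix replaces `pickle.loads()` with `json.loads()`. JSON parsing only produces dicts, lists, strings, numbers, booleans, and `None`, and never imports or calls code, which removes the code execution path. An HMAC-SHA256 signature over the decoded checkpoint bytes is verified with a constant-time comparison before the data is parsed, so only checkpoints produced by a holder of the signing secret are accepted. After parsing, the code checks that the top level is a JSON object and rejects any key outside an allow-list. The server is also started with `debug=False`.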

## Secure Code
```python
import json
import base64
import hmac
import hashlib
import os
from flask import Flask, request, jsonify

app = Flask(__name__)

# Load the signing secret from the environment (or a secret manager) and fail
# fast if it is missing; a hardcoded fallback would let anyone forge signatures.
CHECKPOINT_SIGNING_SECRET = os.environ['CHECKPOINT_SECRET']

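def verify_checkpoint_signature(data_bytes, provided_signature):
    """Verify the hex-encoded HMAC-SHA256 signature of the checkpoint data."""
    # JSON clients can send any type here; reject anything that is not a string
    if not isinstance(provided_signature, str):
        return False
    expected_signature = hmac.new(
        CHECKPOINT_SIGNING_SECRET.encode(),
        data_bytes,
        hashlib.sha256
    ).hexdigest()
    # Compare as bytes: compare_digest raises TypeError for non-ASCII str input
    return hmac.compare_digest(expected_signature.encode(), provided_signature.encode())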

def safe_deserialize_checkpoint(serialized_data):
    """Deserialize checkpoint data using JSON instead of pickle.
    
    Only allows safe data types: dicts, lists, strings, numbers, booleans, None.
    """
    try:
        model_state = json.loads(serialized_data)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise ValueError(f"Invalid checkpoint format: {e}") from e

    if not isinstance(model_state, dict):
        raise ValueError("Checkpoint must be a JSON object at the top level")

    # Validate expected structure
    allowed_keys = {'layers', 'weights', 'optimizer_state', 'epoch', 'loss', 'metadata'}
    unexpected_keys = set(model_state.keys()) - allowed_keys
    if unexpected_keys:
        raise ValueError(f"Unexpected keys in checkpoint: {sorted(unexpected_keys)}")

    # The route calls len() on 'layers', so it must be a list when present
    if not isinstance(model_state.get('layers', []), list):
        raise ValueError("'layers' must be a list")

    return model_state

@app.route('/ai/model/restore', methods=['POST'])
def restore_ml_checkpoint():
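    # silent=True returns None for a missing or malformed JSON body
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    checkpoint_b64 = payload.get('checkpoint_data')
    signature = payload.get('signature')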
    
    if not checkpoint_b64:
        return jsonify({'error': 'No checkpoint provided'}), 400
    
    if not signature:
        return jsonify({'error': 'No signature provided'}), 400
    
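    try:
        # validate=True rejects non-alphabet characters instead of silently dropping them
        serialized_state = base64.b64decode(checkpoint_b64, validate=True)
    except (ValueError, TypeError):
        # binascii.Error is a ValueError subclass; TypeError covers non-string input
        return jsonify({'error': 'Invalid base64 encoding'}), 400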
    
    # Verify the integrity signature before processing
    if not verify_checkpoint_signature(serialized_state, signature):
        return jsonify({'error': 'Invalid checkpoint signature'}), 403
    
    try:
        model_state = safe_deserialize_checkpoint(serialized_state)
    except ValueError as e:
        return jsonify({'error': str(e)}), 400
    
    return jsonify({'status': 'Model restored', 'layers': len(model_state.get('layers', []))})

if __name__ == '__main__':
    app.run(debug=False)
```
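
For completeness, a trusted producer has to generate a request body that passes the signature check above. The following is a minimal sketch under the assumption that the producer shares the same secret string as the server; `build_signed_request` and `demo_secret` are hypothetical names used for illustration only.

```python
import base64
import hashlib
import hmac
import json


def build_signed_request(model_state: dict, secret: str) -> dict:
    """Serialize, sign, and encode a checkpoint for POST /ai/model/restore."""
    raw = json.dumps(model_state).encode("utf-8")
    # Sign the exact bytes that the server will see after base64-decoding.
    signature = hmac.new(secret.encode(), raw, hashlib.sha256).hexdigest()
    return {
        "checkpoint_data": base64.b64encode(raw).decode("ascii"),
        "signature": signature,
    }


# Hypothetical demo values; a real secret comes from a secret manager.
demo_secret = "demo-secret-not-for-production"
body = build_signed_request({"layers": [64, 32, 10], "epoch": 3}, demo_secret)

# Mirror of the server-side check: recompute the HMAC and compare.
decoded = base64.b64decode(body["checkpoint_data"], validate=True)
recomputed = hmac.new(demo_secret.encode(), decoded, hashlib.sha256).hexdigest()
signature_ok = hmac.compare_digest(recomputed, body["signature"])

# Changing even one value in the payload invalidates the signature.
tampered = decoded.replace(b"64", b"65")
tampered_sig = hmac.new(demo_secret.encode(), tampered, hashlib.sha256).hexdigest()
tamper_detected = not hmac.compare_digest(tampered_sig, body["signature"])
```

Because the signature covers the raw serialized bytes rather than the parsed object, the producer and server do not need to agree on a canonical JSON form; they only need to agree on the secret and on signing before base64-encoding.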
