# Pickle Deserialization via `pickle.loads()`

Language: Python
Severity: Critical
CWE: CWE-502

## Source
6

## Flow
6-10-11

## Sink
11

## Vulnerable Code
```python
import pickle
import base64
from flask import request, jsonify

def restore_ml_model_state():
    encrypted_state = request.headers.get('X-Model-Checkpoint')
    if not encrypted_state:
        return jsonify({'error': 'No checkpoint provided'}), 400
    try:
        decoded_checkpoint = base64.b64decode(encrypted_state)
        model_params = pickle.loads(decoded_checkpoint)
        return jsonify({'status': 'Model restored', 'layers': model_params.get('layers', 0)})
    except Exception as e:
        return jsonify({'error': 'Restoration failed'}), 500
```

## Explanation

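The application reads an attacker-controlled value from the `X-Model-Checkpoint` HTTP header, base64-decodes it, and passes the result directly to `pickle.loads()`. Despite the variable name `encrypted_state`, base64 is only an encoding. It provides neither confidentiality nor authentication, so it offers no protection. Unpickling can invoke arbitrary callables: an object's `__reduce__` method tells the unpickler which function to call and with which arguments. An attacker can therefore craft a payload that runs a function such as `os.system` during deserialization. This achieves remote code execution on the server before `model_params.get(...)` is ever reached.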
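
To make the mechanism concrete, here is a minimal, deliberately harmless sketch. The `Payload` class is hypothetical. Calling `eval` on an arithmetic expression stands in for the `os.system` call a real attacker would use. The point is only that `pickle.loads()` invokes whatever callable the payload names.

```python
import base64
import pickle


class Payload:
    """Hypothetical attacker-side class: pickle stores whatever __reduce__
    returns and replays that call inside pickle.loads()."""

    def __reduce__(self):
        # A real exploit would name os.system or similar; this harmless
        # stand-in only evaluates an arithmetic expression.
        return (eval, ("6 * 7", {}))


# Value an attacker would place in the X-Model-Checkpoint header
header_value = base64.b64encode(pickle.dumps(Payload())).decode("ascii")

# What the vulnerable handler does with that header
result = pickle.loads(base64.b64decode(header_value))
print(result)  # 42: eval() ran during deserialization
```

The `Payload` class is not needed on the server side. The pickle stream refers to `builtins.eval` by name, and the unpickler looks it up and calls it.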

## Remediation

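The fix replaces `pickle.loads()` with `json.loads()`. JSON decoding can only produce dicts, lists, strings, numbers, booleans and `None`, so deserialization can no longer instantiate objects or call functions. HMAC-SHA256 verification of a new `X-Checkpoint-Signature` header ensures that only checkpoints signed with the server's secret key (`CHECKPOINT_SECRET_KEY`) are accepted, which prevents tampering. The signature is checked with the constant-time `hmac.compare_digest()` before the payload is parsed. Finally, `validate_model_params()` enforces an allowlist of keys and value types as defense in depth.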
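
The signing scheme assumes that some trusted component produces the checkpoint headers using the same secret key. The sketch below shows a possible round trip under that assumption. The helpers `sign_checkpoint` and `verify_checkpoint` and the key value are hypothetical and are not part of the original code. Like the handler, the verifier authenticates the raw bytes first and parses them only afterwards.

```python
import base64
import hashlib
import hmac
import json
from typing import Optional, Tuple


def sign_checkpoint(params: dict, secret_key: str) -> Tuple[str, str]:
    """Hypothetical producer-side helper: serialize to JSON, then sign the exact bytes."""
    raw = json.dumps(params).encode("utf-8")
    signature = hmac.new(secret_key.encode(), raw, hashlib.sha256).hexdigest()
    # Values for the X-Model-Checkpoint and X-Checkpoint-Signature headers
    return base64.b64encode(raw).decode("ascii"), signature


def verify_checkpoint(encoded: str, signature: str, secret_key: str) -> Optional[dict]:
    """Hypothetical consumer-side helper: authenticate first, parse second."""
    raw = base64.b64decode(encoded, validate=True)
    expected = hmac.new(secret_key.encode(), raw, hashlib.sha256).hexdigest()
    # Compare as bytes; compare_digest raises TypeError for non-ASCII str input
    if not hmac.compare_digest(expected.encode(), signature.encode()):
        return None
    return json.loads(raw)


key = "example-secret"  # hypothetical; the app reads CHECKPOINT_SECRET_KEY instead
encoded, sig = sign_checkpoint({"layers": 12, "epoch": 3}, key)
print(verify_checkpoint(encoded, sig, key))          # {'layers': 12, 'epoch': 3}
print(verify_checkpoint(encoded, sig, "wrong-key"))  # None
```

Signing the serialized bytes, rather than a re-serialized dict, means the verifier never has to parse untrusted input before it has been authenticated.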

## Secure Code
```python
import json
import base64
import hmac
import hashlib
from flask import request, jsonify, current_app

ALLOWED_KEYS = {'layers', 'weights', 'optimizer_state', 'epoch', 'loss', 'learning_rate'}
ALLOWED_VALUE_TYPES = (int, float, str, list, bool, type(None))

def validate_model_params(params: dict) -> bool:
    if not isinstance(params, dict):
        return False
    for key, value in params.items():
        if key not in ALLOWED_KEYS:
            return False
        if not isinstance(value, ALLOWED_VALUE_TYPES):
            return False
    return True

def restore_ml_model_state():
    # Base64 is only a transport encoding here, not encryption
    encoded_state = request.headers.get('X-Model-Checkpoint')
    checkpoint_signature = request.headers.get('X-Checkpoint-Signature')
    if not encoded_state:
        return jsonify({'error': 'No checkpoint provided'}), 400
    if not checkpoint_signature:
        return jsonify({'error': 'No checkpoint signature provided'}), 400
    try:
        # validate=True rejects non-base64 characters instead of silently discarding them
        decoded_checkpoint = base64.b64decode(encoded_state, validate=True)
        secret_key = current_app.config.get('CHECKPOINT_SECRET_KEY', '')
        if not secret_key:
            return jsonify({'error': 'Server configuration error'}), 500
        # Authenticate the raw bytes before parsing them
        expected_sig = hmac.new(secret_key.encode(), decoded_checkpoint, hashlib.sha256).hexdigest()
        # Compare as bytes: compare_digest raises TypeError on non-ASCII str input
        if not hmac.compare_digest(expected_sig.encode(), checkpoint_signature.encode()):
            return jsonify({'error': 'Invalid checkpoint signature'}), 403
        # json.loads only produces dict, list, str, int, float, bool and None
        model_params = json.loads(decoded_checkpoint)
        if not validate_model_params(model_params):
            return jsonify({'error': 'Invalid model parameters'}), 400
        return jsonify({'status': 'Model restored', 'layers': model_params.get('layers', 0)})
    except ValueError:
        # Covers binascii.Error (bad base64), json.JSONDecodeError and UnicodeDecodeError
        return jsonify({'error': 'Invalid checkpoint format'}), 400
    except Exception:
        current_app.logger.exception('Checkpoint restoration failed')
        return jsonify({'error': 'Restoration failed'}), 500
```
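
The JSON step still holds up if signature verification were somehow bypassed, for example through a leaked key. The minimal sketch below is independent of Flask and illustrates why. A pickle stream is not valid JSON. JSON that merely names a callable decodes to a plain string that is never executed. The sample payloads are illustrative only.

```python
import json
import pickle

# A pickle stream, like the one an attacker would send, fails to decode as JSON
pickled = pickle.dumps({"layers": 4})
try:
    json.loads(pickled)
    rejected = False
except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
    rejected = True

# JSON that *names* a callable is just data: a dict holding a string
parsed = json.loads('{"__reduce__": "os.system", "layers": 4}')
print(rejected, type(parsed).__name__, parsed["__reduce__"])  # True dict os.system
```

In the handler, the unexpected `__reduce__` key would then be rejected by the `ALLOWED_KEYS` allowlist in `validate_model_params()`.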
