# Unsafe `marshal` Deserialization via `marshal.loads()`

Language: Python
Severity: Critical
CWE: CWE-502

## Source
9

## Flow
9-10-11

## Sink
11

## Vulnerable Code
```python
import marshal
import base64
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/api/ml/model/predict', methods=['POST'])
def execute_ml_inference():
    serialized_weights = request.json.get('model_weights')
    decoded_weights = base64.b64decode(serialized_weights)
    weight_obj = marshal.loads(decoded_weights)
    prediction_result = weight_obj['predict_fn'](request.json.get('input_data'))
    return jsonify({'prediction': prediction_result, 'status': 'success'})
```

## Explanation

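The handler passes client-controlled bytes from the `model_weights` field to `marshal.loads()`. Unlike `pickle`, `marshal` does not call any functions while loading, but it can reconstruct arbitrary code objects, and the Python documentation warns that the module is not secure against erroneous or maliciously constructed data and must never be used on data from an untrusted or unauthenticated source. The endpoint expects the client to supply `predict_fn`, so an attacker can ship compiled bytecode inside `model_weights`. Once that code object is turned into a function (for example with `types.FunctionType`) or passed to `exec()`, it runs with the privileges of the server process. In the snippet as written, calling the raw code object directly raises `TypeError` because code objects are not callable, but that is an accident of the sample rather than a defense: the intended design, making a deserialized `predict_fn` callable, amounts to executing attacker-supplied bytecode.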
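The following standard-library sketch illustrates why accepting marshaled code is dangerous. The payload source and the `predict_fn` body are hypothetical; the point is that `marshal` faithfully round-trips a code object through the same base64 path the vulnerable handler uses, and a single `types.FunctionType` call turns it into runnable code.

```python
import base64
import marshal
import types

# --- Attacker side (hypothetical payload) ---
# Compile arbitrary source; a real attacker could put any Python here.
attacker_source = (
    "def predict_fn(x):\n"
    "    return 'attacker-controlled: ' + repr(x)\n"
)
module_code = compile(attacker_source, "<payload>", "exec")
# The function body is stored as a nested code object in the module's constants.
fn_code = next(c for c in module_code.co_consts if isinstance(c, types.CodeType))
model_weights = base64.b64encode(marshal.dumps({"predict_fn": fn_code}))

# --- Server side, mirroring the vulnerable handler ---
weight_obj = marshal.loads(base64.b64decode(model_weights))
code_obj = weight_obj["predict_fn"]
print(type(code_obj).__name__)  # code

# Calling the raw code object fails, because code objects are not callable...
try:
    code_obj([1, 2])
except TypeError as exc:
    print("TypeError:", exc)

# ...but wrapping it in a function executes the attacker's bytecode.
predict_fn = types.FunctionType(code_obj, globals())
print(predict_fn([1, 2]))  # attacker-controlled: [1, 2]
```

In a real attack the function body would do something harmful, such as spawning a shell; nothing in `marshal` inspects or restricts what the bytecode does.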

## Remediation

The fix eliminates unsafe `marshal` deserialization entirely by replacing it with a server-side model registry. Instead of accepting serialized executable code from untrusted clients, clients reference pre-registered models by ID, and the request's `input_data` is validated to contain only JSON numbers, lists, and objects before it is passed to a trusted prediction function.
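The original endpoint's `model_weights` field suggests that clients may sometimes need to supply parameters rather than pick a model. If so, the same principle applies: treat the upload as data, never as code. The sketch below assumes a hypothetical JSON weights format with `coefficients` and `bias` keys for a linear model; `parse_weights` and `linear_predict` are illustrative names, not part of the original code. Anything that is not a finite number is rejected, and the arithmetic stays in server-side code.

```python
import json
import math


def parse_weights(raw_json):
    """Parse client-supplied weights as plain JSON data (hypothetical format)."""
    doc = json.loads(raw_json)
    if not isinstance(doc, dict):
        raise ValueError("weights document must be a JSON object")
    coefficients = doc.get("coefficients")
    bias = doc.get("bias", 0.0)
    if not isinstance(coefficients, list) or not coefficients:
        raise ValueError("coefficients must be a non-empty list")
    for value in coefficients + [bias]:
        # bool is a subclass of int, so reject it explicitly; json.loads also
        # accepts NaN/Infinity, which math.isfinite filters out.
        if isinstance(value, bool) or not isinstance(value, (int, float)) or not math.isfinite(value):
            raise ValueError("weights must be finite numbers")
    return [float(c) for c in coefficients], float(bias)


def linear_predict(coefficients, bias, features):
    """Server-side prediction logic; the client only ever supplies numbers."""
    if len(features) != len(coefficients):
        raise ValueError("feature count does not match coefficient count")
    if any(isinstance(x, bool) or not isinstance(x, (int, float)) for x in features):
        raise ValueError("features must be numbers")
    return sum(c * x for c, x in zip(coefficients, features)) + bias


coefficients, bias = parse_weights('{"coefficients": [0.5, -1.0], "bias": 2.0}')
print(linear_predict(coefficients, bias, [4.0, 1.0]))  # 3.0
```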

## Secure Code
```python
from flask import Flask, request, jsonify

app = Flask(__name__)

# Pre-registered models stored server-side, never deserialized from client input
MODEL_REGISTRY = {}


def register_model(model_id, predict_fn):
    """Register a trusted model prediction function server-side."""
    MODEL_REGISTRY[model_id] = predict_fn


def validate_input_data(input_data):
    """Validate that input data contains only numbers, and lists or dicts of numbers."""
    if input_data is None:
        raise ValueError("input_data is required")
    if isinstance(input_data, (int, float)):
        return input_data
    if isinstance(input_data, list):
        return [validate_input_data(item) for item in input_data]
    if isinstance(input_data, dict):
        return {k: validate_input_data(v) for k, v in input_data.items()}
    raise ValueError(f"Unsupported input data type: {type(input_data).__name__}")


@app.route('/api/ml/model/predict', methods=['POST'])
def execute_ml_inference():
    # get_json(silent=True) returns None instead of raising when the body is missing or not valid JSON
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    model_id = payload.get('model_id')
    if not model_id or not isinstance(model_id, str):
        return jsonify({'error': 'A valid model_id is required'}), 400

    # Look up model from server-side registry instead of deserializing from client
    if model_id not in MODEL_REGISTRY:
        return jsonify({'error': f'Model "{model_id}" not found in registry'}), 404

    predict_fn = MODEL_REGISTRY[model_id]

    # Validate input data
    try:
        input_data = validate_input_data(payload.get('input_data'))
    except ValueError as e:
        return jsonify({'error': str(e)}), 400

    try:
        prediction_result = predict_fn(input_data)
    except Exception:
        # Log the details server-side; echoing str(e) back could leak internal information
        app.logger.exception('Prediction failed for model %r', model_id)
        return jsonify({'error': 'Prediction failed'}), 500

    return jsonify({'prediction': prediction_result, 'status': 'success'})


# Example: register models at startup
# register_model('linear_model_v1', my_linear_predict_function)
```
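The registration call at the end of the secure code refers to `my_linear_predict_function`, which the sample does not define. Below is one hypothetical shape for it: the coefficients, bias, and two-feature input are invented for illustration, and `MODEL_REGISTRY`/`register_model` are re-declared only so the snippet runs on its own. The property that matters is that the weights live in server-owned code or trusted storage, so a request can choose a model but never supply its logic.

```python
# Hypothetical server-side weights: shipped with the application (or loaded
# from trusted storage at startup), never taken from a request.
LINEAR_MODEL_V1_COEFFICIENTS = [0.25, 0.75]
LINEAR_MODEL_V1_BIAS = 1.0

# Minimal stand-ins for the registry in the secure code above.
MODEL_REGISTRY = {}


def register_model(model_id, predict_fn):
    MODEL_REGISTRY[model_id] = predict_fn


def my_linear_predict_function(input_data):
    """Predict with server-owned weights; expects a flat list of 2 numbers."""
    if (
        not isinstance(input_data, list)
        or len(input_data) != len(LINEAR_MODEL_V1_COEFFICIENTS)
        or not all(isinstance(x, (int, float)) for x in input_data)
    ):
        raise ValueError("expected a list of 2 numeric features")
    weighted = sum(c * x for c, x in zip(LINEAR_MODEL_V1_COEFFICIENTS, input_data))
    return weighted + LINEAR_MODEL_V1_BIAS


register_model('linear_model_v1', my_linear_predict_function)
print(MODEL_REGISTRY['linear_model_v1']([4, 8]))  # 8.0
```

Because the handler maps any exception from `predict_fn` to a 500 response, a shape error raised here surfaces to the client as a generic prediction failure.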
