{"title":"Pickle Deserialization via `pickle.loads` on Untrusted Data","language":"Python","severity":"Critical","cwe":"CWE-502","source_lines":[9],"flow_lines":[9,12,13],"sink_lines":[13],"vulnerable_code":"import pickle\nimport base64\nfrom flask import Flask, request, jsonify\n\napp = Flask(__name__)\n\n@app.route('/ai/model/restore', methods=['POST'])\ndef restore_ml_checkpoint():\n    checkpoint_b64 = request.json.get('checkpoint_data')\n    if not checkpoint_b64:\n        return jsonify({'error': 'No checkpoint provided'}), 400\n    serialized_state = base64.b64decode(checkpoint_b64)\n    model_state = pickle.loads(serialized_state)\n    return jsonify({'status': 'Model restored', 'layers': len(model_state.get('layers', []))})\n\nif __name__ == '__main__':\n    app.run(debug=True)","explanation":"The `/ai/model/restore` endpoint base64-decodes the client-supplied `checkpoint_data` field and passes the result directly to `pickle.loads()` without any validation or authentication. Unpickling can invoke arbitrary callables (for example through an object's `__reduce__` method), so an attacker can craft a malicious pickle payload that executes arbitrary code on the server during deserialization, achieving remote code execution. The sample also starts the app with `debug=True`, which enables Flask's interactive debugger and must never be exposed in production.","remediation":"The fix replaces `pickle.loads()` with `json.loads()` for deserialization, eliminating the arbitrary code execution risk: JSON decoding only produces plain data (dicts, lists, strings, numbers, booleans, and None) and never instantiates arbitrary objects or invokes callables. Additionally, an HMAC-SHA256 signature over the decoded checkpoint bytes is verified with the constant-time `hmac.compare_digest()` before the data is parsed, ensuring checkpoint integrity and authenticity, and the parsed data is required to be a JSON object whose top-level keys are drawn from an allow-list. Note that the sample looks up `CHECKPOINT_SECRET` in `app.config` at import time, immediately after the app is created, so unless the secret is provisioned before that point the hard-coded placeholder is used, and anyone who knows the placeholder can forge valid signatures; in production, load the secret from a secure secret-management system and refuse to start if it is missing. Finally, the development server is started with `debug=False`.","secure_code":"import json\nimport base64\nimport hmac\nimport hashlib\nfrom flask import Flask, request, jsonify\n\napp = Flask(__name__)\n\n# In production, load this from a secure secret management system\nCHECKPOINT_SIGNING_SECRET = app.config.get('CHECKPOINT_SECRET', 'change-this-to-a-secure-random-secret')\n\ndef verify_checkpoint_signature(data_bytes, provided_signature):\n    \"\"\"Verify HMAC signature of checkpoint data to ensure integrity.\"\"\"\n    expected_signature = hmac.new(\n        CHECKPOINT_SIGNING_SECRET.encode(),\n        data_bytes,\n        hashlib.sha256\n    ).hexdigest()\n    return hmac.compare_digest(expected_signature, provided_signature)\n\ndef safe_deserialize_checkpoint(serialized_data):\n    \"\"\"Deserialize checkpoint data using JSON instead of pickle.\n    \n    Only allows safe data types: dicts, lists, strings, numbers, booleans, None.\n    \"\"\"\n    try:\n        model_state = json.loads(serialized_data)\n    except (json.JSONDecodeError, UnicodeDecodeError) as e:\n        raise ValueError(f\"Invalid checkpoint format: {e}\")\n    \n    if not isinstance(model_state, dict):\n        raise ValueError(\"Checkpoint must be a JSON object at the top level\")\n    \n    # Validate expected structure\n    allowed_keys = {'layers', 'weights', 'optimizer_state', 'epoch', 'loss', 'metadata'}\n    unexpected_keys = set(model_state.keys()) - allowed_keys\n    if unexpected_keys:\n        raise ValueError(f\"Unexpected keys in checkpoint: {unexpected_keys}\")\n    \n    return model_state\n\n@app.route('/ai/model/restore', methods=['POST'])\ndef restore_ml_checkpoint():\n    checkpoint_b64 = request.json.get('checkpoint_data')\n    signature = request.json.get('signature')\n    \n    if not checkpoint_b64:\n        return jsonify({'error': 'No checkpoint provided'}), 400\n    \n    if not signature:\n        return jsonify({'error': 'No signature provided'}), 400\n    \n    try:\n        serialized_state = base64.b64decode(checkpoint_b64)\n    except Exception:\n        return jsonify({'error': 'Invalid base64 encoding'}), 400\n    \n    # Verify the integrity signature before processing\n    if not verify_checkpoint_signature(serialized_state, signature):\n        return jsonify({'error': 'Invalid checkpoint signature'}), 403\n    \n    try:\n        model_state = safe_deserialize_checkpoint(serialized_state)\n    except ValueError as e:\n        return jsonify({'error': str(e)}), 400\n    \n    return jsonify({'status': 'Model restored', 'layers': len(model_state.get('layers', []))})\n\nif __name__ == '__main__':\n    app.run(debug=False)"}