Skip to content

Commit 4e7411b

Browse files
authored
fix: Guard against ssrf attacks when creating image input (#4068)
# Description This pull request add a check to guard against server-side request forgery (ssrf) attacks when the executor service invokes create_image_from_url see: aka.ms/antissrf # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](https://github.com/microsoft/promptflow/blob/main/CONTRIBUTING.md).** - [ ] **I confirm that all new dependencies are compatible with the MIT license.** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes.
1 parent 02f0656 commit 4e7411b

File tree

7 files changed

+391
-6
lines changed

7 files changed

+391
-6
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Export the main classes for easy import
2+
from .anti_ssrf import AntiSSRF
3+
from .anti_ssrf_policy import AntiSSRFPolicy
4+
from .exceptions import AntiSSRFException
5+
6+
# Make classes available for import
7+
__all__ = ["AntiSSRF", "AntiSSRFPolicy", "AntiSSRFException"]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Export the main classes for easy import
2+
from typing import List, Optional
3+
4+
from .anti_ssrf_policy import AntiSSRFPolicy
5+
from .exceptions import AntiSSRFException
6+
7+
8+
# Simple wrapper class that provides validate_url method
9+
class AntiSSRF:
10+
"""Anti-SSRF protection class that validates URLs and network connections."""
11+
12+
def __init__(self, policy: Optional[AntiSSRFPolicy] = None):
13+
"""Initialize AntiSSRF with an optional custom policy."""
14+
self.policy: AntiSSRFPolicy = policy if policy is not None else AntiSSRFPolicy(use_defaults=True)
15+
16+
def validate_url(self, url: str, headers={}) -> None:
17+
"""
18+
Validate a URL against the Anti-SSRF policy.
19+
20+
Args:
21+
url: The URL to validate
22+
23+
Raises:
24+
AntiSSRFException: If the URL is not allowed by the policy
25+
"""
26+
if not url:
27+
return
28+
29+
from urllib.parse import urlparse
30+
31+
# Parse the URL
32+
try:
33+
parsed_url = urlparse(url)
34+
except Exception as e:
35+
raise AntiSSRFException(f"Invalid URL format: {e}")
36+
37+
if not parsed_url.hostname:
38+
raise AntiSSRFException("URL must have a hostname")
39+
40+
# Resolve DNS and check network connections
41+
if parsed_url.hostname != "registries" and parsed_url.hostname != "location.api.azureml.ms":
42+
dns_resolved_ips = self._resolve_hostname(parsed_url.hostname)
43+
44+
if not self.policy.is_network_connection_allowed(dns_resolved_ips):
45+
raise AntiSSRFException(f"Network connection to '{parsed_url.hostname}' is not allowed")
46+
47+
# Check HTTP scheme
48+
if not self.policy.is_http_request_allowed(parsed_url.scheme, headers):
49+
raise AntiSSRFException(f"HTTP scheme '{parsed_url.scheme}' is not allowed")
50+
51+
def _resolve_hostname(self, hostname: str) -> List[str]:
52+
"""Resolve hostname to IP addresses."""
53+
import ipaddress
54+
import socket
55+
56+
# Handle localhost explicitly
57+
if hostname.lower() == "localhost":
58+
return ["127.0.0.1"]
59+
60+
# Try to parse as IP address first
61+
try:
62+
ip_address = ipaddress.ip_address(hostname)
63+
return [str(ip_address)]
64+
except ValueError:
65+
pass # Not an IP address, continue with DNS resolution
66+
67+
# Perform DNS resolution
68+
try:
69+
_, _, ip_addresses = socket.gethostbyname_ex(hostname)
70+
if not ip_addresses:
71+
raise AntiSSRFException(f"No IP addresses resolved for hostname: {hostname}")
72+
return ip_addresses
73+
except socket.gaierror as e:
74+
raise AntiSSRFException(f"DNS resolution failed for hostname '{hostname}': {e}")
75+
except Exception as e:
76+
raise AntiSSRFException(f"Error resolving hostname '{hostname}': {e}")
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import ipaddress
2+
from typing import List, Optional
3+
4+
from .cidr_helpers import IPNetwork, try_parse_cidr_string
5+
from .exceptions import AntiSSRFException
6+
7+
8+
class AntiSSRFPolicy:
9+
def __init__(self, use_defaults: bool = True):
10+
self.AllowedAddresses: List[IPNetwork] = []
11+
self.DeniedAddresses: List[IPNetwork] = []
12+
self.DeniedHeaders: List[str] = []
13+
self.RequiredHeaders: List[str] = []
14+
self.AllowPlainTextHttp: bool = False
15+
self.AddXFFHeader: bool = True
16+
self.DenyAllUnspecifiedIPs: bool = False
17+
18+
if use_defaults:
19+
self._set_defaults()
20+
21+
def add_allowed_addresses(self, networks: List[str]) -> bool:
22+
for network in networks:
23+
outnet = try_parse_cidr_string(network)
24+
self.AllowedAddresses.append(outnet)
25+
return True
26+
27+
def add_denied_addresses(self, networks: List[str]) -> bool:
28+
if self.DenyAllUnspecifiedIPs:
29+
raise AntiSSRFException("Can't add denied networks when * is already supplied")
30+
if not networks:
31+
raise AntiSSRFException("Bad networks parameter")
32+
if len(networks) == 1 and networks[0] == "*":
33+
if len(self.DeniedAddresses) > 0:
34+
raise AntiSSRFException("Can't add * when deny list already has entries")
35+
self.DenyAllUnspecifiedIPs = True
36+
return True
37+
else:
38+
for network in networks:
39+
outnet = try_parse_cidr_string(network)
40+
self.DeniedAddresses.append(outnet)
41+
return True
42+
43+
def add_denied_headers(self, denied_headers: Optional[List[str]]) -> None:
44+
if denied_headers:
45+
self.DeniedHeaders.extend(denied_headers)
46+
47+
def add_required_headers(self, required_headers: Optional[List[str]]) -> None:
48+
if required_headers:
49+
self.RequiredHeaders.extend(required_headers)
50+
51+
def set_allow_plain_text_http(self, allow_plain_text_http: bool = False) -> None:
52+
self.AllowPlainTextHttp = allow_plain_text_http
53+
54+
def add_xff(self, add_xff: bool = True) -> None:
55+
self.AddXFFHeader = add_xff
56+
57+
# IP Addresses in Deny List can be IPv4, IPv6 or IPv4 mapped to IPv6
58+
# Accordingly, to check if an input address from DNS resolution is to be denied, we should:
59+
# 1. Check if the input IP is an IPv4 address, and then check if it is present in deny list
60+
# as a pure IPv4 or an IPv4 mapped to IPv6 format
61+
# 2. Check if the input IP is an IPv6 address, and then check if it is present in deny list
62+
# as a pure IPv6 address. This includes addresses in IPv4 mapped to IPv6 format
63+
# 3. Check if the input IP is an IPv4 mapped to IPv6, then check if it is present in the
64+
# deny list as an IPv4 mapped IPv6 address, then convert it to IPv4 and check if it is
65+
# present in the deny list as a pure IPv4 address
66+
#
67+
# For example, 169.254.169.254, if present in the deny list, should deny DNS resolved
68+
# addresses 169.254.169.254 and ::ffff:a9fe:a9fe
69+
# Likewise ::ffff:a9fe:a9fe, if present in the deny list, should deny DNS resolved
70+
# addresses ::ffff:a9fe:a9fe and 169.254.169.254
71+
#
72+
# Such case-by-case comparisons leads to a lot of branches in code leading to
73+
# sphagettification and also makes code difficult to follow and maintain
74+
# Furthermore, the complexity gets compounded if one adds an allow list to the mix
75+
#
76+
# To make things easier and efficient, we convert every IPv4 address to IPv6 across the
77+
# deny list, allow list and also the input DNS resolved addresses
78+
# The CIDR helper class is accordingly written
79+
#
80+
# As IPv6 is the future anyway, this also makes the code future proof
81+
def is_network_connection_allowed(self, dns_resolved_ip_addresses: List[str]) -> bool:
82+
for ip_str in dns_resolved_ip_addresses:
83+
ip_address = ipaddress.ip_address(ip_str)
84+
ipv6_address = (
85+
ip_address
86+
if isinstance(ip_address, ipaddress.IPv6Address)
87+
else ipaddress.IPv6Address(f"::ffff:{ip_address}")
88+
)
89+
90+
if self.DenyAllUnspecifiedIPs:
91+
# If the address is not in an allow list, it's not allowed.
92+
if not self._networks_contain_address(self.AllowedAddresses, ipv6_address):
93+
return False
94+
elif self.DeniedAddresses:
95+
# If address is in deny list and not in allow list, it's not allowed.
96+
if self._networks_contain_address(
97+
self.DeniedAddresses, ipv6_address
98+
) and not self._networks_contain_address(self.AllowedAddresses, ipv6_address):
99+
return False
100+
# No IP addresses returned by DNS resolution were denied
101+
return True
102+
103+
@staticmethod
104+
def _networks_contain_address(networks: List[IPNetwork], address: ipaddress.IPv6Address) -> bool:
105+
for network in networks:
106+
if network.contains(address):
107+
return True
108+
return False
109+
110+
def is_http_request_allowed(self, scheme: str, headers: dict) -> bool:
111+
if scheme.lower() == "http" and not self.AllowPlainTextHttp:
112+
return False
113+
114+
if self.AddXFFHeader:
115+
if "X-Forwarded-For" not in headers:
116+
headers["X-Forwarded-For"] = "true"
117+
118+
if self.DeniedHeaders:
119+
for header in self.DeniedHeaders:
120+
if header in headers:
121+
return False
122+
123+
if self.RequiredHeaders:
124+
for header in self.RequiredHeaders:
125+
if header not in headers:
126+
return False
127+
128+
return True
129+
130+
def _set_defaults(self):
131+
self.AllowedAddresses = []
132+
self.DeniedAddresses = []
133+
self.RequiredHeaders = []
134+
self.DeniedHeaders = []
135+
self.AllowPlainTextHttp = False
136+
self.DenyAllUnspecifiedIPs = False
137+
self.AddXFFHeader = True
138+
139+
self.add_denied_addresses(
140+
[
141+
# ==== IPv4 ==== #
142+
"255.255.255.255/32",
143+
"168.63.129.16/32", # Not nonroutable,
144+
# but this is the WireServer IP we should block.
145+
"192.0.0.0/24",
146+
"192.0.2.0/24",
147+
"192.88.99.0/24",
148+
"198.51.100.0/24",
149+
"203.0.113.0/24",
150+
"169.254.0.0/16",
151+
"192.168.0.0/16",
152+
"198.18.0.0/15",
153+
"172.16.0.0/12",
154+
"100.64.0.0/10", # IANA-Reserved
155+
"0.0.0.0/8",
156+
"10.0.0.0/8",
157+
"127.0.0.0/8",
158+
"25.0.0.0/8", # GNS Core
159+
"224.0.0.0/4",
160+
"240.0.0.0/4",
161+
# ==== IPv6 ==== #
162+
"::1/128", # Localhost
163+
"FC00::/7", # Unique-local
164+
"fe80::/10", # Link-local
165+
"fec0::/10", # Site-local
166+
"2001::/32", # Teredo
167+
]
168+
)
169+
self.DenyAllUnspecifiedIPs = False
170+
171+
# Deprecated method, for backward compatibility only
172+
def set_defaults(self):
173+
self._set_defaults()
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import ipaddress
2+
from typing import Union
3+
4+
from .exceptions import AntiSSRFException
5+
6+
7+
class IPNetwork:
8+
def __init__(self, ip: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address], prefix: int) -> None:
9+
self._base_address = ipaddress.ip_network(f"{ip}/{prefix}", strict=False)
10+
self._prefix_length = prefix
11+
12+
def contains(self, ip: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]) -> bool:
13+
ip_obj = ip
14+
if isinstance(ip, str):
15+
ip_obj = ipaddress.ip_address(ip)
16+
# Convert IPv4 to IPv6-mapped if base is IPv6
17+
if self._base_address.version == 6 and ip_obj.version == 4:
18+
ip_obj = ipaddress.IPv6Address(f"::ffff:{ip_obj}")
19+
return ip_obj in self._base_address
20+
21+
22+
def _parse_ip_address(ip_string: str) -> Union[ipaddress.IPv4Address, ipaddress.IPv6Address]:
23+
"""Parse IP address from string, raising AntiSSRFException on error."""
24+
try:
25+
return ipaddress.ip_address(ip_string)
26+
except ValueError as e:
27+
raise AntiSSRFException("Bad CIDR", e)
28+
29+
30+
def _parse_prefix_length(prefix_string: str) -> int:
31+
"""Parse prefix length from string, raising AntiSSRFException on error."""
32+
try:
33+
return int(prefix_string)
34+
except ValueError as e:
35+
raise AntiSSRFException("Bad CIDR", e)
36+
37+
38+
def _create_single_ip_network(ip: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]) -> IPNetwork:
39+
"""Create network for single IP address (no prefix specified)."""
40+
if ip.version == 4:
41+
# IPv4 mapped to IPv6, /128
42+
return IPNetwork(f"::ffff:{ip}", 128)
43+
elif ip.version == 6:
44+
return IPNetwork(ip, 128)
45+
else:
46+
raise AntiSSRFException("Bad CIDR")
47+
48+
49+
def _create_prefixed_network(ip: Union[ipaddress.IPv4Address, ipaddress.IPv6Address], prefix_length: int) -> IPNetwork:
50+
"""Create network for IP address with prefix."""
51+
if ip.version == 4 and 0 <= prefix_length <= 32:
52+
# IPv4 mapped to IPv6, prefix + 96
53+
return IPNetwork(f"::ffff:{ip}", prefix_length + 96)
54+
elif ip.version == 6 and 0 <= prefix_length <= 128:
55+
return IPNetwork(ip, prefix_length)
56+
else:
57+
raise AntiSSRFException("Bad CIDR")
58+
59+
60+
# Try parse CIDR string
61+
# Returns an IPNetwork object if everything went fine, or throws an exception
62+
# For easy computation of allow/deny, every IP Address is converted into an IPv6 address
63+
def try_parse_cidr_string(cidr_string: str) -> IPNetwork:
64+
parts = cidr_string.split("/")
65+
ip = _parse_ip_address(parts[0])
66+
67+
if len(parts) == 1:
68+
# e.g. "127.0.0.1" or "::ffff:909:909"
69+
return _create_single_ip_network(ip)
70+
elif len(parts) == 2:
71+
# Cases such as "127.0.0.1/2" or "::ffff:909:909/80"
72+
prefix_length = _parse_prefix_length(parts[1])
73+
return _create_prefixed_network(ip, prefix_length)
74+
else:
75+
raise AntiSSRFException("Bad CIDR")
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class AntiSSRFException(Exception):
2+
def __init__(self, message=None, inner=None):
3+
if inner is not None:
4+
super().__init__(message, inner)
5+
elif message is not None:
6+
super().__init__(message)
7+
else:
8+
super().__init__()

src/promptflow-core/promptflow/_utils/multimedia_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from promptflow._constants import MessageFormatType
1414
from promptflow._utils._errors import InvalidImageInput, InvalidMessageFormatType, LoadMultimediaDataError
15+
from promptflow._utils.anti_ssrf import AntiSSRF, AntiSSRFException
1516
from promptflow._utils.yaml_utils import load_yaml
1617
from promptflow.contracts.flow import FlowInputDefinition
1718
from promptflow.contracts.multimedia import Image, PFBytes, Text
@@ -93,8 +94,30 @@ def create_image_from_base64(base64_str: str, mime_type: str = None):
9394
return Image(image_bytes, mime_type=mime_type)
9495

9596
@staticmethod
96-
def create_image_from_url(url: str, mime_type: str = None):
97-
response = requests.get(url)
97+
def create_image_from_url(url: str, mime_type: str = None) -> Image:
98+
anti_ssrf = AntiSSRF()
99+
anti_ssrf.policy.set_allow_plain_text_http(True)
100+
101+
def block_redirect_if_ssrf(response: requests.Response, *args, **kwargs) -> None:
102+
if not response.is_redirect:
103+
return
104+
105+
anti_ssrf.validate_url(response.headers["Location"])
106+
107+
try:
108+
anti_ssrf.validate_url(url)
109+
110+
# Use the requests "response" hook to allow us to inspect each response
111+
# in a redirect chain.
112+
# See: https://requests.readthedocs.io/en/latest/user/advanced/#event-hooks
113+
response = requests.get(url, hooks={"response": block_redirect_if_ssrf})
114+
except AntiSSRFException as e:
115+
raise InvalidImageInput(
116+
message_format="Failed to fetch image from URL: {url}.",
117+
target=ErrorTarget.EXECUTOR,
118+
url=url,
119+
) from e
120+
98121
if response.status_code == 200:
99122
return Image(response.content, mime_type=mime_type, source_url=url)
100123
else:

0 commit comments

Comments
 (0)