diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 3555a3648..cf8af6edb 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -254,6 +254,23 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]: return container_name, blob_name raise ValueError("Invalid blob URL") + def _resolve_blob_name(self, path: Union[Path, str]) -> str: + """ + Resolve a blob name from either a full blob URL or a relative blob path. + + Args: + path (Union[Path, str]): Blob URL or relative blob path. + + Returns: + str: The resolved blob name. + """ + path_str = str(path) + parsed_url = urlparse(path_str) + if parsed_url.scheme and parsed_url.netloc: + _, blob_name = self.parse_blob_url(path_str) + return blob_name + return path_str + async def read_file(self, path: Union[Path, str]) -> bytes: """ Asynchronously reads the content of a file (blob) from Azure Blob Storage. @@ -284,7 +301,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: if not self._client_async: await self._create_container_client_async() - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) try: blob_client = self._client_async.get_blob_client(blob=blob_name) @@ -311,7 +328,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: """ if not self._client_async: await self._create_container_client_async() - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) try: await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) except Exception as exc: @@ -335,7 +352,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool: if not self._client_async: await self._create_container_client_async() try: - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) blob_client = self._client_async.get_blob_client(blob=blob_name) await blob_client.get_blob_properties() return True @@ -359,7 +376,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: if not self._client_async: await self._create_container_client_async() try: - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) blob_client = self._client_async.get_blob_client(blob=blob_name) blob_properties = await blob_client.get_blob_properties() return blob_properties.size > 0 diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/models/test_storage_io.py index 67e3ece09..1aafa3c26 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/models/test_storage_io.py @@ -101,6 +101,25 @@ async def test_azure_blob_storage_io_read_file(azure_blob_storage_io): assert result == b"Test file content" +@pytest.mark.asyncio +async def test_azure_blob_storage_io_read_file_with_relative_path(azure_blob_storage_io): + mock_container_client = AsyncMock() + azure_blob_storage_io._client_async = mock_container_client + + mock_blob_client = AsyncMock() + mock_blob_stream = AsyncMock() + + mock_container_client.get_blob_client = Mock(return_value=mock_blob_client) + mock_blob_client.download_blob = AsyncMock(return_value=mock_blob_stream) + mock_blob_stream.readall = AsyncMock(return_value=b"Test file content") + mock_container_client.close = AsyncMock() + + result = await azure_blob_storage_io.read_file("dir1/dir2/sample.png") + + assert result == b"Test file content" + mock_container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/sample.png") + + @pytest.mark.asyncio async def test_azure_blob_storage_io_write_file(): container_url = "https://youraccount.blob.core.windows.net/yourcontainer" @@ -143,6 +162,23 @@ async def test_azure_storage_io_path_exists(azure_blob_storage_io): assert exists is True +@pytest.mark.asyncio +async def test_azure_storage_io_path_exists_with_relative_path(azure_blob_storage_io): + mock_container_client = AsyncMock() + azure_blob_storage_io._client_async = mock_container_client + + mock_blob_client = AsyncMock() + + mock_container_client.get_blob_client = Mock(return_value=mock_blob_client) + mock_blob_client.get_blob_properties = AsyncMock() + mock_container_client.close = AsyncMock() + + exists = await azure_blob_storage_io.path_exists("dir1/dir2/blob_name.txt") + + assert exists is True + mock_container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/blob_name.txt") + + @pytest.mark.asyncio async def test_azure_storage_io_is_file(azure_blob_storage_io): azure_blob_storage_io._client_async = AsyncMock()