Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions pyrit/models/storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,23 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]:
return container_name, blob_name
raise ValueError("Invalid blob URL")

def _resolve_blob_name(self, path: Union[Path, str]) -> str:
    """
    Resolve a blob name from either a full blob URL or a relative blob path.

    A value that parses with both a scheme and a network location is treated
    as a full blob URL and handed to ``parse_blob_url``; anything else is
    returned as the blob name itself.

    Args:
        path (Union[Path, str]): Blob URL or relative blob path. Path objects
            are rendered with forward slashes so the resulting blob name does
            not depend on the host operating system's separator. Strings are
            used verbatim.

    Returns:
        str: The resolved blob name.

    Raises:
        ValueError: Propagated from ``parse_blob_url`` when a URL-shaped value
            cannot be split into container and blob name.
    """
    # Function-scope import keeps this helper self-contained; only ``Path``
    # is guaranteed to be imported at module level.
    from pathlib import PurePath

    if isinstance(path, PurePath):
        # str() on a Windows path yields backslashes, which would end up as
        # literal characters in the blob name instead of "/" delimiters.
        path_str = path.as_posix()
    else:
        path_str = str(path)

    parsed_url = urlparse(path_str)
    if parsed_url.scheme and parsed_url.netloc:
        _, blob_name = self.parse_blob_url(path_str)
        return blob_name

    # No scheme/netloc pair: the value already is a container-relative name.
    return path_str
Comment on lines +268 to +272

async def read_file(self, path: Union[Path, str]) -> bytes:
"""
Asynchronously reads the content of a file (blob) from Azure Blob Storage.
Expand Down Expand Up @@ -284,7 +301,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes:
if not self._client_async:
await self._create_container_client_async()

_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)

try:
blob_client = self._client_async.get_blob_client(blob=blob_name)
Expand All @@ -311,7 +328,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None:
"""
if not self._client_async:
await self._create_container_client_async()
_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)
try:
await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type)
except Exception as exc:
Expand All @@ -335,7 +352,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool:
if not self._client_async:
await self._create_container_client_async()
try:
_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)
blob_client = self._client_async.get_blob_client(blob=blob_name)
await blob_client.get_blob_properties()
return True
Expand All @@ -359,7 +376,7 @@ async def is_file(self, path: Union[Path, str]) -> bool:
if not self._client_async:
await self._create_container_client_async()
try:
_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)
blob_client = self._client_async.get_blob_client(blob=blob_name)
blob_properties = await blob_client.get_blob_properties()
return blob_properties.size > 0
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/models/test_storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ async def test_azure_blob_storage_io_read_file(azure_blob_storage_io):
assert result == b"Test file content"


@pytest.mark.asyncio
async def test_azure_blob_storage_io_read_file_with_relative_path(azure_blob_storage_io):
    """read_file accepts a container-relative blob path and uses it as the blob name."""
    # Build the fake client chain from the innermost object outwards:
    # download stream -> blob client -> container client.
    stream = AsyncMock()
    stream.readall = AsyncMock(return_value=b"Test file content")

    blob_client = AsyncMock()
    blob_client.download_blob = AsyncMock(return_value=stream)

    container_client = AsyncMock()
    # get_blob_client is called synchronously, hence a plain Mock.
    container_client.get_blob_client = Mock(return_value=blob_client)
    container_client.close = AsyncMock()
    azure_blob_storage_io._client_async = container_client

    content = await azure_blob_storage_io.read_file("dir1/dir2/sample.png")

    assert content == b"Test file content"
    # The relative path must reach the SDK unchanged.
    container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/sample.png")


@pytest.mark.asyncio
async def test_azure_blob_storage_io_write_file():
container_url = "https://youraccount.blob.core.windows.net/yourcontainer"
Expand Down Expand Up @@ -143,6 +162,23 @@ async def test_azure_storage_io_path_exists(azure_blob_storage_io):
assert exists is True


@pytest.mark.asyncio
async def test_azure_storage_io_path_exists_with_relative_path(azure_blob_storage_io):
    """path_exists accepts a container-relative blob path and uses it as the blob name."""
    # Blob client whose property lookup succeeds, i.e. the blob "exists".
    blob_client = AsyncMock()
    blob_client.get_blob_properties = AsyncMock()

    container_client = AsyncMock()
    # get_blob_client is called synchronously, hence a plain Mock.
    container_client.get_blob_client = Mock(return_value=blob_client)
    container_client.close = AsyncMock()
    azure_blob_storage_io._client_async = container_client

    found = await azure_blob_storage_io.path_exists("dir1/dir2/blob_name.txt")

    assert found is True
    # The relative path must reach the SDK unchanged.
    container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/blob_name.txt")


@pytest.mark.asyncio
async def test_azure_storage_io_is_file(azure_blob_storage_io):
azure_blob_storage_io._client_async = AsyncMock()
Expand Down
Loading