Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 109 additions & 5 deletions python/lib/sift_client/_internal/low_level_wrappers/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
import logging
from typing import TYPE_CHECKING, Any, Sequence, cast

from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier, ResourceIdentifiers
from sift.common.type.v1.resource_identifier_pb2 import (
NamedResources,
Names,
ResourceIdentifier,
ResourceIdentifiers,
)
from sift.rule_evaluation.v1.rule_evaluation_pb2 import (
AssetsTimeRange,
EvaluateRulesAnnotationOptions,
EvaluateRulesFromCurrentRuleVersions,
EvaluateRulesFromReportTemplate,
EvaluateRulesFromRuleVersions,
EvaluateRulesRequest,
EvaluateRulesResponse,
RunTimeRange,
Expand All @@ -15,6 +24,8 @@
ArchiveRuleRequest,
BatchArchiveRulesRequest,
BatchGetRulesRequest,
BatchGetRuleVersionsRequest,
BatchGetRuleVersionsResponse,
BatchUnarchiveRulesRequest,
BatchUpdateRulesRequest,
BatchUpdateRulesResponse,
Expand All @@ -24,7 +35,11 @@
CreateRuleResponse,
GetRuleRequest,
GetRuleResponse,
GetRuleVersionRequest,
GetRuleVersionResponse,
ListRulesRequest,
ListRuleVersionsRequest,
ListRuleVersionsResponse,
RuleAssetConfiguration,
RuleConditionExpression,
UnarchiveRuleRequest,
Expand All @@ -45,6 +60,7 @@
Rule,
RuleCreate,
RuleUpdate,
RuleVersion,
)
from sift_client.sift_types.tag import Tag
from sift_client.transport import GrpcClient, WithGrpcClient
Expand Down Expand Up @@ -506,6 +522,57 @@ async def list_all_rules(
max_results=max_results,
)

async def list_rule_versions(
self,
rule_id: str,
*,
filter_query: str | None = None,
order_by: str | None = None,
page_size: int | None = None,
page_token: str | None = None,
) -> tuple[list[RuleVersion], str]:
"""List rule versions for a rule.

Args:
rule_id: The rule ID to list versions for.
filter_query: Optional CEL filter (fields: rule_version_id, user_notes, change_message).
order_by: Unused, for _handle_pagination compatibility.
page_size: Maximum number of versions per page.
page_token: Token for the next page.

Returns:
Tuple of (list of RuleVersions, next page token or empty string).
"""
_ = order_by
request_kwargs: dict[str, Any] = {
"rule_id": rule_id,
"page_size": page_size or DEFAULT_PAGE_SIZE,
"page_token": page_token or "",
}
if filter_query:
request_kwargs["filter"] = filter_query
request = ListRuleVersionsRequest(**request_kwargs)
response = await self._grpc_client.get_stub(RuleServiceStub).ListRuleVersions(request)
response = cast("ListRuleVersionsResponse", response)
versions = [RuleVersion._from_proto(p) for p in response.rule_versions]
return versions, response.next_page_token or ""

async def list_all_rule_versions(
self,
rule_id: str,
*,
filter_query: str | None = None,
max_results: int | None = None,
page_size: int | None = DEFAULT_PAGE_SIZE,
) -> list[RuleVersion]:
"""List all rule versions for a rule, with optional CEL filter."""
return await self._handle_pagination(
self.list_rule_versions,
kwargs={"rule_id": rule_id, "filter_query": filter_query},
page_size=page_size,
max_results=max_results,
)

async def evaluate_rules(
self,
*,
Expand Down Expand Up @@ -571,13 +638,22 @@ async def evaluate_rules(
if all_applicable_rules:
kwargs["all_applicable_rules"] = all_applicable_rules
if rule_ids:
kwargs["rules"] = {"rules": ResourceIdentifiers(ids={"ids": rule_ids})} # type: ignore
kwargs["rules"] = EvaluateRulesFromCurrentRuleVersions(
rules=ResourceIdentifiers(ids={"ids": rule_ids}) # type: ignore[arg-type]
)
if rule_version_ids:
kwargs["rule_versions"] = rule_version_ids
kwargs["rule_versions"] = EvaluateRulesFromRuleVersions(
rule_version_ids=rule_version_ids
)
if report_template_id:
kwargs["report_template"] = report_template_id
kwargs["report_template"] = EvaluateRulesFromReportTemplate(
report_template=ResourceIdentifier(id=report_template_id)
)
if tags:
kwargs["tags"] = [tag.name if isinstance(tag, Tag) else tag for tag in tags]
tag_names = [tag.name if isinstance(tag, Tag) else tag for tag in tags]
kwargs["annotation_options"] = EvaluateRulesAnnotationOptions(
tags=NamedResources(names=Names(names=tag_names)) # type: ignore[arg-type]
)
if report_name:
kwargs["report_name"] = report_name
if organization_id:
Expand All @@ -595,3 +671,31 @@ async def evaluate_rules(
report = await ReportsLowLevelClient(self._grpc_client).get_report(report_id=report_id)
return created_annotation_count, report, job_id
return created_annotation_count, None, job_id

async def get_rule_version(self, rule_version_id: str) -> Rule:
"""Get a rule at a specific version by rule_version_id.

Args:
rule_version_id: The rule version ID to get.

Returns:
The Rule at that version.
"""
request = GetRuleVersionRequest(rule_version_id=rule_version_id)
response = await self._grpc_client.get_stub(RuleServiceStub).GetRuleVersion(request)
grpc_rule = cast("GetRuleVersionResponse", response).rule
return Rule._from_proto(grpc_rule)

async def batch_get_rule_versions(self, rule_version_ids: list[str]) -> list[Rule]:
"""Get multiple rules at specific versions by rule_version_ids.

Args:
rule_version_ids: The rule version IDs to get.

Returns:
List of Rules at those versions (order may match request order).
"""
request = BatchGetRuleVersionsRequest(rule_version_ids=rule_version_ids)
response = await self._grpc_client.get_stub(RuleServiceStub).BatchGetRuleVersions(request)
response = cast("BatchGetRuleVersionsResponse", response)
return [Rule._from_proto(r) for r in response.rules]
50 changes: 39 additions & 11 deletions python/lib/sift_client/_tests/resources/test_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ def test_client_binding(sift_client):

@pytest.mark.integration
class TestReports:
def test_create_from_rule_versions(self, nostromo_run, test_rule, sift_client):
"""Create a report from specific rule version IDs."""
rule_versions = sift_client.rules.list_rule_versions(test_rule)
assert rule_versions, "test_rule should have at least one version"
report = sift_client.reports.create_from_rule_versions(
name="report_from_rule_versions",
run=nostromo_run,
organization_id=nostromo_run.organization_id,
rule_versions=[rule_versions[0].rule_version_id],
)
assert report is not None
assert report.run_id == nostromo_run.id_
assert report.name == "report_from_rule_versions"

def test_create_from_rule_versions_with_rule_version_objects(
self, nostromo_run, test_rule, sift_client
):
"""Create a report passing RuleVersion instances."""
rule_versions = sift_client.rules.list_rule_versions(test_rule)
assert rule_versions
report = sift_client.reports.create_from_rule_versions(
name="report_from_rule_versions_objs",
run=nostromo_run,
organization_id=nostromo_run.organization_id,
rule_versions=rule_versions[:1],
)
assert report is not None
assert report.run_id == nostromo_run.id_

def test_create_from_rules(self, nostromo_run, test_rule, sift_client):
report_from_rules = sift_client.reports.create_from_rules(
name="report_from_rules",
Expand Down Expand Up @@ -146,17 +175,16 @@ def test_archive(self, nostromo_run, test_rule, sift_client):
assert archived_report is not None
assert archived_report.is_archived == True

def test_unarchive(self, sift_client):
reports_from_rules = sift_client.reports.list_(
name="report_from_rules", include_archived=True
def test_unarchive(self, nostromo_run, test_rule, sift_client):
# create a report, archive it, then unarchive it
report_from_rules = sift_client.reports.create_from_rules(
name="report_from_rules_unarchive",
run=nostromo_run,
rules=[test_rule],
)
report_from_rules = None
for report_from_rules in reports_from_rules:
if report_from_rules.is_archived:
report_from_rules = report_from_rules
break
assert report_from_rules is not None
assert report_from_rules.is_archived == True
unarchived_report = sift_client.reports.unarchive(report=report_from_rules)
archived_report = sift_client.reports.archive(report=report_from_rules)
assert archived_report.is_archived is True
unarchived_report = sift_client.reports.unarchive(report=archived_report)
assert unarchived_report is not None
assert unarchived_report.is_archived == False
assert unarchived_report.is_archived is False
113 changes: 113 additions & 0 deletions python/lib/sift_client/_tests/resources/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RuleAnnotationType,
RuleCreate,
RuleUpdate,
RuleVersion,
)

pytestmark = pytest.mark.integration
Expand Down Expand Up @@ -215,6 +216,118 @@ async def test_list_with_time_filters(self, rules_api_async):
for rule in rules:
assert rule.created_date >= one_year_ago

class TestListRuleVersions:
"""Tests for the async list_rule_versions method."""

@pytest.mark.asyncio
async def test_list_rule_versions_by_rule(self, rules_api_async, test_rule):
"""Test listing rule versions for a rule."""
versions = await rules_api_async.list_rule_versions(test_rule)
assert isinstance(versions, list)
assert len(versions) >= 1
for v in versions:
assert isinstance(v, RuleVersion)
assert v.rule_id == test_rule.id_
assert v.rule_version_id
assert v.version
assert v.created_date

@pytest.mark.asyncio
async def test_list_rule_versions_by_rule_id_str(self, rules_api_async, test_rule):
"""Test listing rule versions by rule ID string."""
versions = await rules_api_async.list_rule_versions(test_rule.id_)
assert isinstance(versions, list)
assert len(versions) >= 1
for v in versions:
assert v.rule_id == test_rule.id_

@pytest.mark.asyncio
async def test_list_rule_versions_with_limit(self, rules_api_async, test_rule):
"""Test listing rule versions with limit."""
versions = await rules_api_async.list_rule_versions(test_rule, limit=1)
assert isinstance(versions, list)
assert len(versions) <= 1
if versions:
assert isinstance(versions[0], RuleVersion)

@pytest.mark.asyncio
async def test_list_rule_versions_with_rule_version_ids_filter(
self, rules_api_async, test_rule
):
"""Test listing rule versions filtered by rule_version_ids."""
all_versions = await rules_api_async.list_rule_versions(test_rule)
assert all_versions
first_id = all_versions[0].rule_version_id
versions = await rules_api_async.list_rule_versions(
test_rule, rule_version_ids=[first_id]
)
assert len(versions) == 1
assert versions[0].rule_version_id == first_id

class TestGetRuleVersion:
"""Tests for the async get_rule_version method."""

@pytest.mark.asyncio
async def test_get_rule_version_by_id(self, rules_api_async, test_rule):
"""Test getting a rule at a specific version by rule_version_id."""
versions = await rules_api_async.list_rule_versions(test_rule)
assert versions
rule_at_version = await rules_api_async.get_rule_version(versions[0].rule_version_id)
assert rule_at_version is not None
assert rule_at_version.id_ == test_rule.id_
assert rule_at_version.rule_version is not None
assert rule_at_version.rule_version.rule_version_id == versions[0].rule_version_id

@pytest.mark.asyncio
async def test_get_rule_version_by_rule_version_instance(self, rules_api_async, test_rule):
"""Test getting a rule at a specific version by passing RuleVersion instance."""
versions = await rules_api_async.list_rule_versions(test_rule)
assert versions
rule_at_version = await rules_api_async.get_rule_version(versions[0])
assert rule_at_version is not None
assert rule_at_version.id_ == test_rule.id_
assert rule_at_version.rule_version.rule_version_id == versions[0].rule_version_id

class TestBatchGetRuleVersions:
"""Tests for the async batch_get_rule_versions method."""

@pytest.mark.asyncio
async def test_batch_get_rule_versions_by_ids(self, rules_api_async, test_rule):
"""Test batch getting rules by rule_version_id strings."""
versions = await rules_api_async.list_rule_versions(test_rule)
assert versions
ids = [v.rule_version_id for v in versions[:2]]
rules = await rules_api_async.batch_get_rule_versions(ids)
assert len(rules) == len(ids)
returned_ids = {r.rule_version.rule_version_id for r in rules if r.rule_version}
assert returned_ids >= set(ids)
for r in rules:
assert r.id_ == test_rule.id_

@pytest.mark.asyncio
async def test_batch_get_rule_versions_by_rule_version_instances(
self, rules_api_async, test_rule
):
"""Test batch getting rules by passing RuleVersion instances."""
versions = await rules_api_async.list_rule_versions(test_rule)
assert versions
rules = await rules_api_async.batch_get_rule_versions(versions[:2])
assert len(rules) <= 2
for r in rules:
assert r.id_ == test_rule.id_
if len(versions) >= 2:
assert len(rules) == 2

@pytest.mark.asyncio
async def test_batch_get_rule_versions_single(self, rules_api_async, test_rule):
"""Test batch_get_rule_versions with a single version ID."""
versions = await rules_api_async.list_rule_versions(test_rule)
assert versions
rules = await rules_api_async.batch_get_rule_versions([versions[0].rule_version_id])
assert len(rules) == 1
assert rules[0].id_ == test_rule.id_
assert rules[0].rule_version.rule_version_id == versions[0].rule_version_id

class TestFind:
"""Tests for the async find method."""

Expand Down
Loading
Loading