Module ogr.services.forgejo.pull_request

Classes

class ForgejoPullRequest (raw_pr: pyforgejo.types.pull_request.PullRequest,
project: forgejo.ForgejoProject)
Expand source code
class ForgejoPullRequest(BasePullRequest):
    """
    Pull request backed by a raw pyforgejo pull request object.

    Most properties read straight from the wrapped raw object; the mutating
    methods call the API client of the target project.
    """

    _target_project: "forgejo.ForgejoProject" = None
    _source_project: "forgejo.ForgejoProject" = None
    # Cache filled by the `labels` property and replaced by `add_label`.
    _labels: Optional[list[PRLabel]] = None

    def __init__(
        self,
        raw_pr: PyforgejoPullRequest,
        project: "forgejo.ForgejoProject",
    ):
        """
        Args:
            raw_pr: Raw pull request object as returned by pyforgejo.
            project: Project the pull request is opened against.
        """
        super().__init__(raw_pr, project)

    def __str__(self) -> str:
        """Return the base class representation prefixed with `Forgejo`."""
        return "Forgejo" + super().__str__()

    @property
    def title(self) -> str:
        """Title of the pull request."""
        return self._raw_pr.title

    @title.setter
    def title(self, new_title: str) -> None:
        # Setting the title goes through the API, see `update_info`.
        self.update_info(title=new_title)

    @property
    def id(self) -> int:
        """Number of the pull request (the raw object's `number` field)."""
        return self._raw_pr.number

    @property
    def status(self) -> PRStatus:
        """
        Status of the pull request.

        The raw `merged` flag takes precedence; otherwise the raw `state`
        string is looked up by name in `PRStatus`.
        """
        if self._raw_pr.merged:
            return PRStatus.merged
        return PRStatus[self._raw_pr.state]

    @property
    def url(self) -> str:
        """URL of the pull request as reported by the raw object."""
        return self._raw_pr.url

    @property
    def description(self) -> str:
        """Body text of the pull request."""
        return self._raw_pr.body

    @description.setter
    def description(self, new_description: str) -> None:
        # Setting the description goes through the API, see `update_info`.
        self.update_info(description=new_description)

    @property
    def author(self) -> str:
        """Login of the user who opened the pull request."""
        return self._raw_pr.user.login

    @property
    def source_branch(self) -> str:
        """Ref of the head (source) side of the pull request."""
        return self._raw_pr.head.ref

    @property
    def target_branch(self) -> str:
        """Ref of the base (target) side of the pull request."""
        return self._raw_pr.base.ref

    @property
    def created(self) -> datetime.datetime:
        """Creation time as reported by the raw object."""
        return self._raw_pr.created_at

    @property
    def labels(self) -> list[PRLabel]:
        """
        Labels of the pull request.

        The wrapped labels are cached in `_labels`; an empty or missing cache
        is rebuilt from the raw object on every access.
        """
        if not self._labels:
            raw_labels = self._raw_pr.labels
            self._labels = (
                [ForgejoPRLabel(raw_label, self) for raw_label in raw_labels]
                if raw_labels
                else []
            )
        return self._labels

    @property
    def diff_url(self) -> str:
        """URL of the diff as reported by the raw object."""
        return self._raw_pr.diff_url

    @property
    def patch(self) -> bytes:
        """
        Content of the patch, downloaded from the raw object's `patch_url`.

        Raises:
            OgrNetworkError: If the HTTP response is not OK.
        """
        patch_url = self._raw_pr.patch_url
        response = requests.get(patch_url)

        if not response.ok:
            # Report the URL that was actually requested; appending another
            # `.patch` suffix would point at a different address.
            raise OgrNetworkError(
                f"Couldn't get patch from {patch_url} because {response.reason}.",
            )

        return response.content

    @property
    def head_commit(self) -> str:
        """SHA of the head commit."""
        return self._raw_pr.head.sha

    @property
    def merge_commit_sha(self) -> Optional[str]:
        """SHA of the merge commit."""
        # this is None for non-merged PRs
        return self._raw_pr.merge_commit_sha

    @property
    def merge_commit_status(self) -> MergeCommitStatus:
        """
        Mergeability derived from the raw `mergeable` flag.

        Only `can_be_merged` and `cannot_be_merged` are ever returned.
        """
        if self._raw_pr.mergeable:
            return MergeCommitStatus.can_be_merged
        return MergeCommitStatus.cannot_be_merged

    @cached_property
    def source_project(self) -> "forgejo.ForgejoProject":
        """
        Project the pull request originates from.

        Resolved through the service of the target project from the head
        repository of the raw object; computed once per instance.
        """
        pyforgejo_repo = self._raw_pr.head.repo
        return self._target_project.service.get_project(
            repo=pyforgejo_repo.name,
            namespace=pyforgejo_repo.owner.login,
            forgejo_repo=pyforgejo_repo,
        )

    @property
    def commits_url(self) -> str:
        """URL of the pull request with `/commits` appended."""
        return f"{self.url}/commits"

    @property
    def closed_by(self) -> Optional[str]:
        """Login from the raw `merged_by` field, or `None` when it is unset."""
        merged_by = self._raw_pr.merged_by
        return merged_by.login if merged_by else None

    @staticmethod
    def create(
        project: "forgejo.ForgejoProject",
        title: str,
        body: str,
        target_branch: str,
        source_branch: str,
        fork_username: Optional[str] = None,
    ) -> "PullRequest":
        """
        Create a new pull request.

        Args:
            project: Project the method is called on.
            title: Title of the pull request.
            body: Description of the pull request.
            target_branch: Branch the changes are to be merged into.
            source_branch: Branch the changes come from.
            fork_username: Namespace of a fork; selects between the
                fork -> upstream and fork -> other fork scenarios.

                Defaults to `None`.

        Returns:
            Object representing the newly created pull request.

        Raises:
            ForgejoAPIException: If no fork of the parent project lives in
                the `fork_username` namespace.
        """
        target_project = project

        if project.is_fork and fork_username is None:
            # handles fork -> upstream (called on fork)
            source_branch = f"{project.namespace}:{source_branch}"
            target_project = project.parent  # type: ignore
        elif fork_username:
            if fork_username != project.namespace and project.parent is not None:
                # handles fork -> other_fork
                #   (username of other_fork owner specified by fork_username)
                forks = [
                    fork
                    for fork in project.parent.get_forks()
                    if fork.namespace == fork_username
                ]
                if not forks:
                    raise ForgejoAPIException("Requested fork doesn't exist")
                target_project = forks[0]  # type: ignore
                source_branch = f"{project.namespace}:{source_branch}"
            else:
                # handles fork -> upstream
                #   (username of fork owner specified by fork_username)
                source_branch = f"{fork_username}:{source_branch}"

        logger.debug(f"Creating PR {target_branch}<-{source_branch}")

        pr = target_project.api.repo_create_pull_request(
            owner=target_project.namespace,
            repo=target_project.repo,
            base=target_branch,
            body=body,
            head=source_branch,
            title=title,
        )
        # Log the number: it is what the `id` property exposes and what
        # `get` expects as `pr_id`.
        logger.info(f"PR {pr.number} created.")

        return ForgejoPullRequest(pr, target_project)

    @staticmethod
    def get(project: "forgejo.ForgejoProject", pr_id: int) -> "PullRequest":
        """
        Get the pull request with the given number.

        Args:
            project: Project the pull request belongs to.
            pr_id: Number of the pull request.

        Returns:
            Object representing the pull request.

        Raises:
            ForgejoAPIException: If the API reports the pull request as
                not found.
        """
        try:
            raw_pr = project.api.repo_get_pull_request(
                owner=project.namespace,
                repo=project.repo,
                index=pr_id,
            )
        except NotFoundError as ex:
            raise ForgejoAPIException(f"No pull request with id {pr_id} found.") from ex
        return ForgejoPullRequest(raw_pr, project)

    @staticmethod
    def get_list(
        project: "forgejo.ForgejoProject",
        status: PRStatus = PRStatus.open,
    ) -> Iterable["PullRequest"]:
        """
        List pull requests of the project.

        Args:
            project: Project whose pull requests are listed.
            status: Status used to select the pull requests.

                Defaults to `PRStatus.open`.

        Returns:
            Lazy iterable of the matching pull requests.
        """
        # Forgejo has just open/closed/all
        # NOTE(review): nothing filters on the merged flag afterwards, so a
        # `PRStatus.merged` query yields whatever the API returns for
        # `closed` - confirm whether callers expect merged-only results.
        state = "closed" if status == PRStatus.merged else status.name
        prs = paginate(
            partial(
                project.api.repo_list_pull_requests,
                owner=project.namespace,
                repo=project.repo,
                state=state,
            ),
        )
        return (ForgejoPullRequest(pr, project) for pr in prs)

    def update_info(
        self,
        title: Optional[str] = None,
        description: Optional[str] = None,
    ) -> "PullRequest":
        """
        Update the title and/or the description of the pull request.

        Args:
            title: New title; a falsy value keeps the current title.
            description: New description; `None` leaves it untouched.

        Returns:
            This pull request, carrying the raw data returned by the API.

        Raises:
            ForgejoAPIException: If anything fails while updating.
        """
        try:
            # The title is always sent; the body only when it was given.
            data = {"title": title or self.title}

            if description is not None:
                data["body"] = description

            updated_pr = self._target_project.api.repo_edit_pull_request(
                owner=self.target_project.namespace,
                repo=self.target_project.repo,
                index=self.id,
                **data,
            )

            self._raw_pr = updated_pr
            return self
        except Exception as ex:
            raise ForgejoAPIException(
                f"There was an error while updating Forgejo PR: {ex}",
            ) from ex

    def close(self) -> "PullRequest":
        """
        Close the pull request.

        Returns:
            This pull request, carrying the raw data returned by the API.
        """
        self._raw_pr = self._target_project.api.repo_edit_pull_request(
            owner=self.target_project.namespace,
            repo=self.target_project.repo,
            index=self.id,
            state="closed",
        )
        return self

    def merge(self) -> "PullRequest":
        """
        Merge the pull request using the `merge` strategy.

        Returns:
            This pull request, reloaded from the API after the merge.

        Raises:
            ForgejoAPIException: If the pull request cannot be found when
                it is reloaded.
        """
        self._target_project.api.repo_merge_pull_request(
            owner=self.target_project.namespace,
            repo=self.target_project.repo,
            index=self.id,
            # options: merge, rebase, rebase-merge, squash, fast-forward-only, manually-merged
            do="merge",
        )
        # Reload into this object, like `close` and `update_info` do, so that
        # callers which keep using it do not read pre-merge data.
        refreshed = self.get(self._target_project, self.id)
        self._raw_pr = refreshed._raw_pr
        return self

    def add_label(self, *labels: str) -> None:
        """
        Add labels to the pull request.

        The label cache is replaced by the labels the API returns.

        Args:
            *labels: Labels to be added.
        """
        issue_client = self._target_project.service.api.issue
        new_labels = issue_client.add_label(
            owner=self.target_project.namespace,
            repo=self.target_project.repo,
            index=self.id,
            labels=list(labels),
        )
        self._labels = [ForgejoPRLabel(raw_label, self) for raw_label in new_labels]

    def get_all_commits(self) -> Iterable[str]:
        """
        Get the commits of the pull request.

        Returns:
            Lazy iterable of commit SHAs.
        """
        commits = paginate(
            partial(
                self._target_project.api.repo_get_pull_request_commits,
                owner=self.target_project.namespace,
                repo=self.target_project.repo,
                index=self.id,
            ),
        )
        return (commit.sha for commit in commits)

    def get_comments(
        self,
        filter_regex: Optional[str] = None,
        reverse: bool = False,
        author: Optional[str] = None,
    ) -> Union[list["PRComment"], Iterable["PRComment"]]:
        """
        Get list of pull request comments.

        Not implemented for Forgejo yet.

        Args:
            filter_regex: Filter the comments' content with `re.search`.

                Defaults to `None`, which means no filtering.
            reverse: Whether the comments are to be returned in
                reversed order.

                Defaults to `False`.
            author: Filter the comments by author.

                Defaults to `None`, which means no filtering.

        Returns:
            List of pull request comments.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def comment(
        self,
        body: str,
        commit: Optional[str] = None,
        filename: Optional[str] = None,
        row: Optional[int] = None,
    ) -> "PRComment":
        """
        Add new comment to the pull request.

        Not implemented for Forgejo yet.

        Args:
            body: Body of the comment.
            commit: Commit hash to which comment is related.

                Defaults to generic comment.
            filename: Path to the file to which comment is related.

                Defaults to no relation to the file.
            row: Line number to which the comment is related.

                Defaults to no relation to the line.

        Returns:
            Newly created comment.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def get_comment(self, comment_id: int) -> PRComment:
        """
        Returns a PR comment.

        Not implemented for Forgejo yet.

        Args:
            comment_id: id of comment

        Returns:
            Object representing a PR comment.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def get_statuses(self) -> Union[list[CommitFlag], Iterable[CommitFlag]]:
        """
        Returns statuses for latest commit on pull request.

        Not implemented for Forgejo yet.

        Returns:
            List of commit statuses of the latest commit.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

Attributes

project : GitProject
Project of the pull request.

Ancestors

Inherited members