Skip to content

feat: Add support for Amazon AWS and Google Drive #511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions datashuttle/configs/canonical_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ def get_canonical_configs() -> dict:
canonical_configs = {
"local_path": Union[str, Path],
"central_path": Optional[Union[str, Path]],
"connection_method": Optional[Literal["ssh", "local_filesystem"]],
"connection_method": Optional[
Literal["ssh", "local_filesystem", "aws", "gdrive"]
],
"central_host_id": Optional[str],
"central_host_username": Optional[str],
# AWS S3 specific configs
"aws_bucket_name": Optional[str],
"aws_region": Optional[str],
# Google Drive specific configs
"gdrive_folder_id": Optional[str],
}

return canonical_configs
Expand Down Expand Up @@ -101,43 +108,57 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None:

raise_on_bad_local_only_project_configs(config_dict)

# Check connection-specific requirements AFTER the general consistency check
if config_dict["connection_method"] == "ssh" and (
not config_dict["central_host_id"]
or not config_dict["central_host_username"]
):
utils.log_and_raise_error(
"SSH connection requires both 'central_host_id' and 'central_host_username'.",
ConfigError,
)
elif (
config_dict["connection_method"] == "aws"
and not config_dict["aws_bucket_name"]
):
utils.log_and_raise_error(
"AWS connection requires 'aws_bucket_name' to be specified.",
ConfigError,
)
elif (
config_dict["connection_method"] == "gdrive"
and not config_dict["gdrive_folder_id"]
):
utils.log_and_raise_error(
"Google Drive connection requires 'gdrive_folder_id' to be specified.",
ConfigError,
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice input checks!

if list(config_dict.keys()) != list(canonical_dict.keys()):
utils.log_and_raise_error(
f"New config keys are in the wrong order. The"
f" order should be: {canonical_dict.keys()}.",
f" order should be: {list(canonical_dict.keys())}.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! This could be a separate PR

ConfigError,
)

raise_on_bad_path_syntax(
config_dict["local_path"].as_posix(), "local_path"
)

if config_dict["central_path"] is not None:
raise_on_bad_path_syntax(
config_dict["central_path"].as_posix(), "central_path"
)

# Check SSH settings
if config_dict["connection_method"] == "ssh" and (
not config_dict["central_host_id"]
or not config_dict["central_host_username"]
):
utils.log_and_raise_error(
"'central_host_id' and 'central_host_username' are "
"required if 'connection_method' is 'ssh'.",
ConfigError,
)

# Initialise the local project folder
utils.print_message_to_user(
f"Making project folder at: {config_dict['local_path']}"
)
try:
folders.create_folders(config_dict["local_path"])
except OSError:
except OSError as e:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, because error messages are propagated to the TUI, we generally don't show the entire native error and instead replace it with something a bit less detailed.

utils.log_and_raise_error(
f"Could not make project folder at: {config_dict['local_path']}. "
f"Config file not updated.",
f"Error: {e}. Config file not updated.",
RuntimeError,
)

Expand Down
16 changes: 13 additions & 3 deletions datashuttle/configs/config_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(
self.hostkeys_path: Path
self.ssh_key_path: Path
self.project_metadata_path: Path
self.aws_config_path: Path
self.gdrive_config_path: Path

def setup_after_load(self) -> None:
load_configs.convert_str_and_pathlib_paths(self, "str_to_path")
Expand Down Expand Up @@ -227,19 +229,27 @@ def make_rclone_transfer_options(
}

def init_paths(self) -> None:
    """
    Set up the filesystem paths used for project configuration and logs.

    Populates the metadata, SSH, logging, AWS and Google Drive path
    attributes, derived from the project name and the local path.
    """
    datashuttle_path, _ = canonical_folders.get_project_datashuttle_path(
        self.project_name
    )

    self.project_metadata_path = self["local_path"] / ".datashuttle"

    # Per-project SSH credential and host-key files live under the
    # datashuttle folder, prefixed with the project name.
    self.ssh_key_path = datashuttle_path / f"{self.project_name}_ssh_key"
    self.hostkeys_path = datashuttle_path / "hostkeys"

    # Remote-specific rclone configuration files (AWS S3, Google Drive).
    self.aws_config_path = (
        datashuttle_path / f"{self.project_name}_aws_config"
    )
    self.gdrive_config_path = (
        datashuttle_path / f"{self.project_name}_gdrive_config"
    )

    self.logging_path = self.make_and_get_logging_path()

def make_and_get_logging_path(
self,
) -> Path:
Expand Down
195 changes: 176 additions & 19 deletions datashuttle/datashuttle_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from datashuttle.configs.config_class import Configs
from datashuttle.datashuttle_functions import _format_top_level_folder
from datashuttle.utils import (
aws,
ds_logger,
folders,
formatting,
Expand All @@ -53,6 +54,8 @@
from datashuttle.utils.decorators import ( # noqa
check_configs_set,
check_is_not_local_project,
requires_aws_configs,
requires_gdrive_configs,
requires_ssh_configs,
)

Expand Down Expand Up @@ -833,7 +836,7 @@ def _transfer_specific_file_or_folder(
utils.log(output.stderr.decode("utf-8"))

# -------------------------------------------------------------------------
# SSH
# Connection Setup (SSH, AWS, GDrive)
# -------------------------------------------------------------------------

@requires_ssh_configs
Expand Down Expand Up @@ -892,6 +895,109 @@ def write_public_key(self, filepath: str) -> None:
public.write(key.get_base64())
public.close()

@check_configs_set
@check_is_not_local_project
def setup_gdrive_connection(self) -> None:
    """
    Walk the user through setting up a Google Drive connection.

    Google Drive authentication requires browser interaction, so this
    method triggers the interactive Rclone configuration and prints
    guidance rather than fully automating the process.

    Raises
    ------
    ConfigError
        If 'gdrive_folder_id' has not been set in the config file.
    """
    self._start_log("setup-gdrive-connection", local_vars=locals())

    utils.log_and_message(
        "Setting up Google Drive connection via Rclone..."
    )

    # The folder ID must be configured up front; the rclone setup
    # cannot proceed without it.
    folder_id = self.cfg["gdrive_folder_id"]
    if not folder_id:
        utils.log_and_raise_error(
            "Google Drive folder ID is not configured. Update config file with 'gdrive_folder_id'.",
            ConfigError,
        )

    self._setup_rclone_central_gdrive_config(log=True)

    utils.log_and_message(
        "Once you've completed the interactive setup in your terminal, "
        "your Google Drive connection should be ready to use."
    )

    ds_logger.close_log_filehandler()

def _setup_rclone_central_gdrive_config(self, log: bool = True) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in this case, because this function just directly wraps rclone's setup_rclone_config_for_gdrive with no other effects, calls to this function can be replaced with direct calls to setup_rclone_config_for_gdrive

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this uses the existing structure for the SSH case, which is good, but in this case the SSH structure is not ideal! So that could be factored out in a separate PR.

"""
Provide instructions for Google Drive Rclone setup.

This doesn't actually perform the setup, but provides the command
to run and instructions for the user.

Parameters
----------
log : bool
Whether to log the process
"""
rclone.setup_rclone_config_for_gdrive(
self.cfg,
self.cfg.get_rclone_config_name("gdrive"),
log=log,
)

@check_configs_set
@check_is_not_local_project
def setup_aws_connection(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very well structured, and it is easy to read alongside the gdrive function because they are all structured in a similar way

"""
Setup AWS S3 connection and configure Rclone.

This method verifies AWS credentials can access the specified bucket,
then sets up Rclone configuration for the project.

Assumes AWS credentials are already set up externally (environment
variables, AWS CLI config, instance profile, etc).
"""
self._start_log("setup-aws-connection", local_vars=locals())

utils.log_and_message("Setting up AWS S3 connection via Rclone...")

if not self.cfg["aws_bucket_name"]:
utils.log_and_raise_error(
"AWS bucket name is not configured. Update config file with 'aws_bucket_name'.",
ConfigError,
)

self._setup_rclone_central_aws_config(log=True)

verification_result = aws.verify_aws_credentials_with_logging(
self.cfg, message_on_successful_connection=True
)

if not verification_result:
utils.log_and_message(
"Failed to verify AWS connection. Check credentials and try again."
)
else:
utils.log_and_message("AWS S3 connection setup successful!")

ds_logger.close_log_filehandler()

def _setup_rclone_central_aws_config(self, log: bool = True) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment here as for the gdrive function but with setup_rclone_config_for_aws

"""
Configure Rclone for AWS S3 access.

This creates/updates an rclone remote configuration for the AWS S3 bucket.

Parameters
----------
log : bool
Whether to log the process
"""
rclone.setup_rclone_config_for_aws(
self.cfg,
self.cfg.get_rclone_config_name("aws"),
log=log,
)

# -------------------------------------------------------------------------
# Configs
# -------------------------------------------------------------------------
Expand All @@ -903,6 +1009,9 @@ def make_config_file(
connection_method: str | None = None,
central_host_id: Optional[str] = None,
central_host_username: Optional[str] = None,
aws_bucket_name: Optional[str] = None,
aws_region: Optional[str] = None,
gdrive_folder_id: Optional[str] = None,
) -> None:
"""
Initialise the configurations for datashuttle to use on the
Expand All @@ -924,26 +1033,41 @@ def make_config_file(
path to project folder on local machine

central_path :
Filepath to central project.
If this is local (i.e. connection_method = "local_filesystem"),
this is the full path on the local filesystem
Otherwise, if this is via ssh (i.e. connection method = "ssh"),
this is the path to the project folder on central machine.
This should be a full path to central folder i.e. this cannot
include ~ home folder syntax, must contain the full path
(e.g. /nfs/nhome/live/jziminski)
Meaning depends on connection_method:
- local_filesystem: full filesystem path on local machine
- ssh: full path on SSH server
- aws: optional prefix path within the S3 bucket
- gdrive: optional path within the Google Drive folder
- None (local-only): ignored/not used

connection_method :
The method used to connect to the central project filesystem,
e.g. "local_filesystem" (e.g. mounted drive) or "ssh"
The method used to connect to the central project filesystem:
- "local_filesystem": mounted drive on local machine
- "ssh": remote SSH server
- "aws": Amazon S3 bucket
- "gdrive": Google Drive folder
- None: local-only mode (no central storage)

central_host_id :
server address for central host for ssh connection
e.g. "ssh.swc.ucl.ac.uk"
For SSH: server address for central host (e.g. "ssh.example.com")
For other connection methods: ignored

central_host_username :
username for which to log in to central host.
e.g. "jziminski"
For SSH: username to log in to central host (e.g. "username")
For other connection methods: ignored

aws_bucket_name :
For AWS: name of the S3 bucket to use
For other connection methods: ignored

aws_region :
For AWS: region of the S3 bucket (e.g. "us-east-1")
Optional: if None, uses default credentials region
For other connection methods: ignored

gdrive_folder_id :
For GDrive: ID of the Google Drive folder to use as root
For other connection methods: ignored
"""
self._start_log(
"make-config-file",
Expand All @@ -958,6 +1082,24 @@ def make_config_file(
RuntimeError,
)

if connection_method == "ssh" and (
not central_host_id or not central_host_username
):
utils.log_and_raise_error(
"SSH connection requires both 'central_host_id' and 'central_host_username'.",
ConfigError,
)
elif connection_method == "aws" and not aws_bucket_name:
utils.log_and_raise_error(
"AWS connection requires 'aws_bucket_name' to be specified.",
ConfigError,
)
elif connection_method == "gdrive" and not gdrive_folder_id:
utils.log_and_raise_error(
"Google Drive connection requires 'gdrive_folder_id' to be specified.",
ConfigError,
)

cfg = Configs(
self.project_name,
self._config_path,
Expand All @@ -967,19 +1109,34 @@ def make_config_file(
"connection_method": connection_method,
"central_host_id": central_host_id,
"central_host_username": central_host_username,
"aws_bucket_name": aws_bucket_name,
"aws_region": aws_region,
"gdrive_folder_id": gdrive_folder_id,
},
)

cfg.setup_after_load() # will raise error if fails
cfg.setup_after_load()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good point that this function has assert behavior that is only indicated by a stray comment here. It would be better to rename this function to indicate it sets up and asserts. Even better, because ideally each function will only do one thing, we could have two functions setup_after_load() and assert_config_dict_is_valid(), though the logic flow under the hood might make this tricky; at least renaming the function would help. This could be a separate PR.

self.cfg = cfg

self.cfg.dump_to_file()

self._set_attributes_after_config_load()

# This is just a placeholder rclone config that will suffice
# if central is a 'local filesystem'.
self._setup_rclone_central_local_filesystem_config()
if connection_method == "local_filesystem":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically, the config setup for transfers is done in the bespoke transfer-setup functions (i.e. setup_ssh_connection); the benefit is that local-only mode (i.e. no transfer) users don't have to worry about the setup. Ideally, the central path would not be set up here either, and the existing code was quite confusing. The problem is the local filesystem version is free to set up and there is no bespoke 'setup local filesystem' function, so we just do it here. But it is not ideal...

self._setup_rclone_central_local_filesystem_config()
elif connection_method == "ssh":
self._setup_rclone_central_ssh_config(log=True)
elif connection_method == "aws":
self._setup_rclone_central_aws_config(log=True)
elif connection_method == "gdrive":
utils.log_and_message(
"Google Drive connection requires interactive setup. "
"Run setup_gdrive_connection() after config initialization."
)
elif connection_method is None:
pass
else:
self._setup_rclone_central_local_filesystem_config()

utils.log_and_message(
"Configuration file has been saved and "
Expand Down
Loading
Loading