From 735131637424c9c4ac2f3b32fcd3c401b2619d69 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 21 Jan 2026 19:06:22 +0000 Subject: [PATCH 01/12] try setup het gs inference --- .../graph_store/heterogeneous_inference.py | 482 ++++++++++++++++++ 1 file changed, 482 insertions(+) create mode 100644 examples/link_prediction/graph_store/heterogeneous_inference.py diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py new file mode 100644 index 00000000..dddcdfd4 --- /dev/null +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -0,0 +1,482 @@ +""" +This file contains an example for how to run heterogeneous inference on pretrained torch.nn.Module in GiGL (or elsewhere) using new +GLT (GraphLearn-for-PyTorch) bindings that GiGL has. Note that example should be applied to use cases which already have +some pretrained `nn.Module` and are looking to utilize cost-savings with distributed inference. While `run_example_inference` is coupled with +GiGL orchestration, the `_inference_process` function is generic and can be used as references +for writing inference for pipelines not dependent on GiGL orchestration. + +To run this file with GiGL orchestration, set the fields similar to below: + +inferencerConfig: + inferencerArgs: + # Example argument to inferencer + log_every_n_batch: "50" + inferenceBatchSize: 512 + command: python -m examples.link_prediction.heterogeneous_inference +featureFlags: + should_run_glt_backend: 'True' + +You can run this example in a full pipeline with `make run_het_dblp_sup_test` from GiGL root. +""" + +import argparse +import gc +import sys +import time + +import torch +import torch.distributed +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_heterogeneous_model + +import gigl.distributed +import gigl.distributed.utils +from gigl.common import GcsUri, UriFactory +from gigl.common.data.export import EmbeddingExporter, load_embeddings_to_bigquery +from gigl.common.logger import Logger +from gigl.common.utils.gcs import GcsUtils +from gigl.distributed.graph_store.compute import init_compute_process +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils import get_graph_store_info +from gigl.env.distributed import GraphStoreInfo +from gigl.nn import LinkPredictionGNN +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.common.utils.model import load_state_dict_from_uri +from gigl.src.inference.lib.assets import InferenceAssets +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +@torch.no_grad() +def _inference_process( + # When spawning processes, each process will be assigned a rank ranging + # from [0, num_processes). 
+    local_rank: int,
+    local_world_size: int,
+    machine_rank: int,
+    machine_world_size: int,
+    cluster_info: GraphStoreInfo,
+    embedding_gcs_path: GcsUri,
+    model_state_dict_uri: GcsUri,
+    inference_batch_size: int,
+    hid_dim: int,
+    out_dim: int,
+    inferencer_args: dict[str, str],
+    inference_node_type: NodeType,
+    node_type_to_feature_dim: dict[NodeType, int],
+    edge_type_to_feature_dim: dict[EdgeType, int],
+    mp_sharing_dict: dict[str, torch.Tensor],
+):
+    """
+    This function is spawned by multiple processes per machine and is responsible for:
+    1. Initializing the DataLoader
+    2. Running the inference loop to get the embeddings for each anchor node
+    3. Writing embeddings to GCS
+
+    Args:
+        local_rank (int): Process number on the current machine
+        local_world_size (int): Number of inference processes spawned by each machine
+        machine_rank (int): Machine number in the distributed setup
+        machine_world_size (int): Total number of machines in the distributed setup
+        master_ip_address (str): IP address of the master node in the distributed setup
+        master_default_process_group_port (int): Port on the master node in the distributed setup to setup Torch process group on
+        embedding_gcs_path (GcsUri): GCS path to load embeddings from
+        model_state_dict_uri (GcsUri): GCS path to load model from
+        inference_batch_size (int): Batch size to use for inference
+        hid_dim (int): Hidden dimension of the model
+        out_dim (int): Output dimension of the model
+        dataset (DistDataset): Loaded Distributed Dataset for inference
+        inferencer_args (dict[str, str]): Additional arguments for inferencer
+        inference_node_type (NodeType): Node Type that embeddings should be generated for in current inference process. This is used to
+            tag the embeddings written to GCS.
+        node_type_to_feature_dim (dict[NodeType, int]): Input node feature dimension per node type for the model
+        edge_type_to_feature_dim (dict[EdgeType, int]): Input edge feature dimension per edge type for the model
+        mp_sharing_dict (dict[str, torch.Tensor]): Shared memory dictionary for sharing data between processes
+    """
+
+    # Parses the fanout as a string.
+    # For the heterogeneous case, the fanouts can be specified as a string of a list of integers, such as "[10, 10]", which will apply this fanout
+    # to each edge type in the graph, or as a string of format dict[tuple[str, str, str], list[int]] which will specify fanouts per edge type.
+    # In the case of the latter, the keys should be specified with format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE).
+    # For the default example, we make a decision to keep the fanouts for all edge types the same, specifying the `fanout` with a `list[int]`.
+    # To see an example of a 'fanout' with different behaviors per edge type, refer to `examples/link_prediction.configs/e2e_het_dblp_sup_task_config.yaml`.
+
+    fanout = inferencer_args.get("num_neighbors", "[10, 10]")
+    num_neighbors = parse_fanout(fanout)
+
+    # While the ideal value for `sampling_workers_per_inference_process` has been identified to be between `2` and `4`, this may need some tuning depending on the
+    # pipeline. We default this value to `4` here for simplicity. A `sampling_workers_per_inference_process` which is too small may not have enough parallelization for
+    # sampling, which would slow down inference, while a value which is too large may slow down each sampling process due to competing resources, which would also
+    # then slow down inference.
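    # Illustrative aside, a minimal sketch with assumed (hypothetical) values: the knobs
    # discussed above all arrive through `inferencer_args` as strings, and the fanout
    # string is parsed by `parse_fanout` (imported at the top of this file) in either of
    # the two formats described in the comment above.
    example_inferencer_args = {
        "num_neighbors": '{("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 20]}',
        "sampling_workers_per_inference_process": "2",
        "sampling_worker_shared_channel_size": "8GB",
    }
    example_uniform_fanout = parse_fanout("[10, 10]")  # same fanout for every edge type
    example_per_edge_fanout = parse_fanout(example_inferencer_args["num_neighbors"])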
+ sampling_workers_per_inference_process: int = int( + inferencer_args.get("sampling_workers_per_inference_process", "4") + ) + + # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and + # is the place to store pre-fetched data, so if it is too small then prefetching is limited, causing sampling slowdown. This parameter is a string + # with `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default this value to 4GB, + # but in production may need some tuning. + sampling_worker_shared_channel_size: str = inferencer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + device = gigl.distributed.utils.get_available_device( + local_process_rank=local_rank, + ) # The device is automatically inferred based off the local process rank and the available devices + rank = machine_rank * local_world_size + local_rank + world_size = machine_world_size * local_world_size + if torch.cuda.is_available(): + torch.cuda.set_device( + device + ) # Set the device for the current process. Without this, NCCL will fail when multiple GPUs are available. + # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster. + # If this is not done, the dataloader will not be able to sample from the graph store and will crash. + init_compute_process(local_rank, cluster_info) + dataset = RemoteDistDataset( + cluster_info, local_rank, mp_sharing_dict=mp_sharing_dict + ) + logger.info( + f"Local rank {local_rank} in machine {machine_rank} has rank {rank}/{world_size} and using device {device} for inference" + ) + + input_nodes = dataset.get_node_ids(inference_node_type) + data_loader = gigl.distributed.DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + # We must pass in a tuple of (node_type, node_ids_on_current_process) for heterogeneous input + input_nodes=(inference_node_type, input_nodes), + num_workers=sampling_workers_per_inference_process, + batch_size=inference_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_inference_process, + channel_size=sampling_worker_shared_channel_size, + # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders + # don't compete for memory during initialization, causing OOM + process_start_gap_seconds=0, + ) + # Initialize a LinkPredictionGNN model and load parameters from + # the saved model. + model_state_dict = load_state_dict_from_uri( + load_from_uri=model_state_dict_uri, device=device + ) + model: LinkPredictionGNN = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=node_type_to_feature_dim, + edge_type_to_feature_dim=edge_type_to_feature_dim, + hid_dim=hid_dim, + out_dim=out_dim, + device=device, + state_dict=model_state_dict, + ) + + # Set the model to evaluation mode for inference. + model.eval() + + logger.info(f"Model initialized on device {device}") + + embedding_filename = f"machine_{machine_rank}_local_process_number_{local_rank}" + + # Get temporary GCS folder to write outputs of inference to. GiGL orchestration automatic cleans this, but + # if running manually, you will need to clean this directory so that retries don't end up with stale files. 
+ gcs_utils = GcsUtils() + gcs_base_uri = GcsUri.join(embedding_gcs_path, embedding_filename) + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path. Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) + + # GiGL class for exporting embeddings to GCS. This is achieved by writing ids and embeddings to an in-memory buffer which gets + # flushed to GCS. Setting the min_shard_size_threshold_bytes field of this class sets the frequency of flushing to GCS, and defaults + # to only flushing when flush_records() is called explicitly or after exiting via a context manager. + exporter = EmbeddingExporter(export_dir=gcs_base_uri) + + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. + sys.stdout.flush() + # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph + # sampling may fail. + + torch.distributed.barrier() + + t = time.time() + data_loading_start_time = time.time() + inference_start_time = time.time() + cumulative_data_loading_time = 0.0 + cumulative_inference_time = 0.0 + + # Begin inference loop + + # Iterating through the dataloader yields a `torch_geometric.data.Data` type + for batch_idx, data in enumerate(data_loader): + cumulative_data_loading_time += time.time() - data_loading_start_time + + inference_start_time = time.time() + + # These arguments to forward are specific to the GiGL heterogeneous LinkPredictionGNN model. + # If just using a nn.Module, you can just use output = model(data) + output = model( + data=data, output_node_types=[inference_node_type], device=device + )[inference_node_type] + + # The anchor node IDs are contained inside of the .batch field of the data + node_ids = data[inference_node_type].batch.cpu() + + # Only the first `batch_size` rows of the node embeddings contain the embeddings of the anchor nodes + node_embeddings = output[: data[inference_node_type].batch_size].cpu() + + # We add ids and embeddings to the in-memory buffer + exporter.add_embedding( + id_batch=node_ids, + embedding_batch=node_embeddings, + embedding_type=str(inference_node_type), + ) + + cumulative_inference_time += time.time() - inference_start_time + + if batch_idx > 0 and batch_idx % log_every_n_batch == 0: + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. + sys.stdout.flush() + logger.info( + f"Rank {rank} processed {batch_idx} batches for node type {inference_node_type}. " + f"{log_every_n_batch} batches took {time.time() - t:.2f} seconds for node type {inference_node_type}. " + f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." + f"and model inference took {cumulative_inference_time:.2f} seconds." + ) + t = time.time() + cumulative_data_loading_time = 0 + cumulative_inference_time = 0 + + data_loading_start_time = time.time() + + logger.info( + f"--- Rank {rank} finished inference for node type {inference_node_type}." 
+    )
+
+    write_embedding_start_time = time.time()
+    # Flushes all remaining embeddings to GCS
+    exporter.flush_records()
+
+    logger.info(
+        f"--- Rank {rank} finished writing embeddings to GCS for node type {inference_node_type}, which took {time.time()-write_embedding_start_time:.2f} seconds"
+    )
+
+    # We first call barrier to ensure that all machines and processes have finished inference.
+    # Only once all machines have finished inference is it safe to shutdown the data loader.
+    # Otherwise, processes which are still sampling *will* fail as the loaders they are trying to communicate with will be shutdown.
+    # We then call `gc.collect()` to cleanup the memory used by the data_loader on the current machine.
+
+    torch.distributed.barrier()
+
+    data_loader.shutdown()
+    gc.collect()
+
+    logger.info(
+        f"--- All machines local rank {local_rank} finished inference for node type {inference_node_type}. Deleted data loader"
+    )
+
+
+def _run_example_inference(
+    job_name: str,
+    task_config_uri: str,
+) -> None:
+    """
+    Runs an example inference pipeline using GiGL Orchestration.
+    Args:
+        job_name (str): Name of current job
+        task_config_uri (str): Path to frozen GBMLConfigPbWrapper
+    """
+    # All machines run this logic to connect together, and return a distributed context with:
+    # - the (GCP) internal IP address of the rank 0 machine, which will be used for building RPC connections.
+    # - the current machine rank
+    # - the total number of machines (world size)
+
+    program_start_time = time.time()
+    # The main processes (one per machine) need to be able to talk to each other to partition and synchronize the graph data.
+    # Thus, the user is responsible here for 1. spinning up a single process per machine,
+    # and 2. calling init_process_group amongst these processes.
+    # Since this is assumed to be spinning up inside VAI, which already sets up the env:// init method for us, we don't need anything
+    # special here.
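    # Aside, a minimal sketch of the assumption above: the default env:// rendezvous used
    # by the call below reads these environment variables, which Vertex AI (VAI) sets
    # automatically. The values here are placeholders you would only set by hand when
    # running outside of VAI.
    import os

    os.environ.setdefault("MASTER_ADDR", "10.128.0.2")  # internal IP of the rank-0 machine (placeholder)
    os.environ.setdefault("MASTER_PORT", "29500")  # a free port on that machine (placeholder)
    os.environ.setdefault("RANK", "0")  # this machine's index in [0, WORLD_SIZE)
    os.environ.setdefault("WORLD_SIZE", "2")  # total number of machines in the job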
+ torch.distributed.init_process_group(backend="gloo") + + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) + cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + torch.distributed.destroy_process_group() + + # Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path, and additional inference args + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + + node_type_to_feature_dim: dict[NodeType, int] = { + graph_metadata.condensed_node_type_to_node_type_map[ + condensed_node_type + ]: node_feature_dim + for condensed_node_type, node_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map.items() + } + + edge_type_to_feature_dim: dict[EdgeType, int] = { + graph_metadata.condensed_edge_type_to_edge_type_map[ + condensed_edge_type + ]: edge_feature_dim + for condensed_edge_type, edge_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map.items() + } + + inference_node_types = sorted( + gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_task_root_node_types() + ) + + inferencer_args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) + + inference_batch_size = gbml_config_pb_wrapper.inferencer_config.inference_batch_size + + hid_dim = int(inferencer_args.get("hid_dim", "16")) + out_dim = int(inferencer_args.get("out_dim", "16")) + + if torch.cuda.is_available(): + default_num_inference_processes_per_machine = torch.cuda.device_count() + else: + default_num_inference_processes_per_machine = 2 + num_inference_processes_per_machine = int( + inferencer_args.get( + "num_inference_processes_per_machine", + default_num_inference_processes_per_machine, + ) + ) # Current large-scale setting sets this value to 4 + + if ( + torch.cuda.is_available() + and num_inference_processes_per_machine > torch.cuda.device_count() + ): + raise ValueError( + f"Number of inference processes per machine ({num_inference_processes_per_machine}) must not be more than the number of GPUs: ({torch.cuda.device_count()})" + ) + + master_ip_address = gigl.distributed.utils.get_internal_ip_from_master_node() + machine_rank = torch.distributed.get_rank() + machine_world_size = torch.distributed.get_world_size() + master_default_process_group_port = ( + gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1)[0] + ) + + ## Inference Start + + inference_start_time = time.time() + + for process_num, inference_node_type in enumerate(inference_node_types): + logger.info( + f"Starting inference process for node type {inference_node_type} ..." + ) + output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( + gbml_config_pb_wrapper, inference_node_type + ) + + bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( + bq_table_path=output_bq_table_path + ) + + # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. + # After inference has finished, we then load all embeddings to bigquery from GCS. 
+ embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( + applied_task_identifier=AppliedTaskIdentifier(job_name), + bq_table_path=output_bq_table_path, + ) + mp_sharing_dict = mp.Manager().dict() + if cluster_info.compute_node_rank == 0: + gcs_utils = GcsUtils() + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path( + embedding_output_gcs_folder + ) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path {embedding_output_gcs_folder}. Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(embedding_output_gcs_folder) + + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. + mp.spawn( + fn=_inference_process, + args=( + num_inference_processes_per_machine, # local_world_size + machine_rank, # machine_rank + cluster_info, # cluster_info + master_default_process_group_port, # master_default_process_group_port + embedding_output_gcs_folder, # embedding_gcs_path + model_uri, # model_state_dict_uri + inference_batch_size, # inference_batch_size + hid_dim, # hid_dim + out_dim, # out_dim + inferencer_args, # inferencer_args + inference_node_type, # inference_node_type + node_type_to_feature_dim, # node_type_to_feature_dim + edge_type_to_feature_dim, # edge_type_to_feature_dim + mp_sharing_dict, # mp_sharing_dict + ), + nprocs=num_inference_processes_per_machine, + join=True, + ) + + logger.info( + f"--- Inference finished on rank {machine_rank} for node type {inference_node_type}, which took {time.time()-inference_start_time:.2f} seconds" + ) + + # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. + if machine_rank == 0: + logger.info( + f"--- Machine 0 triggers loading embeddings from GCS to BigQuery for node type {inference_node_type}" + ) + # If we are on the last inference process, we should wait for this last write process to complete. Otherwise, we should + # load embeddings to bigquery in the background so that we are not blocking the start of the next inference process + should_run_async = process_num != len(inference_node_types) - 1 + + # The `load_embeddings_to_bigquery` API returns a BigQuery LoadJob object + # representing the load operation, which allows user to monitor and retrieve + # details about the job status and result. 
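    # Illustrative aside (hypothetical usage, assuming the standard google-cloud-bigquery
    # LoadJob interface): instead of discarding the handle as done below, it could be kept
    # and monitored, e.g.
    #
    #     load_job = load_embeddings_to_bigquery(...)  # same arguments as the call below
    #     load_job.result()                            # block until the load completes
    #     logger.info(f"Load state: {load_job.state}, rows loaded: {load_job.output_rows}")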
+ _ = load_embeddings_to_bigquery( + gcs_folder=embedding_output_gcs_folder, + project_id=bq_project_id, + dataset_id=bq_dataset_id, + table_id=bq_table_name, + should_run_async=should_run_async, + ) + + logger.info( + f"--- Program finished, which took {time.time()-program_start_time:.2f} seconds" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model inference on VertexAI" + ) + parser.add_argument( + "--job_name", + type=str, + help="Inference job name", + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + + # We only need `job_name` and `task_config_uri` for running inference + _run_example_inference( + job_name=args.job_name, + task_config_uri=args.task_config_uri, + ) From 3100d7532143cd184057b75e2dd7d9325f413681 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 21 Jan 2026 19:07:55 +0000 Subject: [PATCH 02/12] configs --- Makefile | 8 ++ .../configs/e2e_het_dblp_sup_task_config.yaml | 76 +++++++++++++++++++ testing/e2e_tests/e2e_tests.yaml | 3 + 3 files changed, 87 insertions(+) create mode 100644 examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_task_config.yaml diff --git a/Makefile b/Makefile index e8914b4f..952ab77e 100644 --- a/Makefile +++ b/Makefile @@ -245,6 +245,14 @@ run_hom_cora_sup_gs_e2e_test: --test_spec_uri="testing/e2e_tests/e2e_tests.yaml" \ --test_names="hom_cora_sup_gs_test" +run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} +run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline +run_het_dblp_sup_gs_e2e_test: + uv run python testing/e2e_tests/e2e_test.py \ + --compiled_pipeline_path=$(compiled_pipeline_path) \ + --test_spec_uri="testing/e2e_tests/e2e_tests.yaml" \ + --test_names="het_dblp_sup_gs_test" + run_all_e2e_tests: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_all_e2e_tests: compile_gigl_kubeflow_pipeline run_all_e2e_tests: diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_task_config.yaml new file mode 100644 index 00000000..5f848ed5 --- /dev/null +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_task_config.yaml @@ -0,0 +1,76 @@ +# This config is used to run heterogeneous DBLP self-supervised training and inference using in memory GiGL SGS. This can be run with `make run_het_dblp_sup_test`. +graphMetadata: + # We have 3 nodes types in the DBLP Dataset: author, paper, and term. We also have 3 + # edge types: author -> paper, paper -> author, and term -> paper + edgeTypes: + - dstNodeType: paper + relation: to + srcNodeType: author + - dstNodeType: author + relation: to + srcNodeType: paper + - dstNodeType: paper + relation: to + srcNodeType: term + nodeTypes: + - author + - paper + - term +taskMetadata: + nodeAnchorBasedLinkPredictionTaskMetadata: + # We aim to predict paper -> author links in the graph. 
+    supervisionEdgeTypes:
+      - dstNodeType: author
+        relation: to
+        srcNodeType: paper
+datasetConfig:
+  dataPreprocessorConfig:
+    dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets
+    dataPreprocessorArgs:
+      # This argument is specific to the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using
+      mocked_dataset_name: 'dblp_node_anchor_edge_features_lp'
+# TODO(kmonte): Add GS trainer
+trainerConfig:
+  trainerArgs:
+    # Example argument to trainer
+    log_every_n_batch: "50"
+    # The DBLP Dataset does not have specified labeled edges so we provide this field to indicate what
+    # percentage of edges we should select as self-supervised labeled edges. Doing this randomly sets 5% as "labels".
+    # Note that the current GiGL implementation does not remove these selected edges from the global set of edges, which may
+    # have a slight negative impact on training specifically with self-supervised learning. This will be improved on in the future.
+    ssl_positive_label_percentage: "0.05"
+    # Example of a dictionary fanout which has different fanout-per-hop for each edge type. Currently, we assume that all anchor node types
+    # use the same fanout. If you want different anchor node types to have different fanouts, we encourage adding additional arguments here to parse
+    # fanouts for each anchor node type.
+    # Note that edge types must be provided as a tuple[str, str, str] in format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE), as demonstrated below.
+    num_neighbors: >-
+      {
+        ("term", "to", "paper"): [10, 10],
+        ("paper", "to", "author"): [15, 15],
+        ("author", "to", "paper"): [20, 20]
+      }
+  command: python -m examples.link_prediction.heterogeneous_training
+# TODO(kmonte): Move to user-defined server code
+inferencerConfig:
+  inferencerArgs:
+    # Example argument to inferencer
+    log_every_n_batch: "50"
+    # Example of a dictionary fanout which has different fanout-per-hop for each edge type. Currently, we assume that all anchor node types
+    # use the same fanout. If you want different anchor node types to have different fanouts, we encourage adding additional arguments here to parse
+    # fanouts for each anchor node type.
+    # Note that edge types must be provided as a tuple[str, str, str] in format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE), as demonstrated below.
+    num_neighbors: >-
+      {
+        ("term", "to", "paper"): [10, 10],
+        ("paper", "to", "author"): [15, 15],
+        ("author", "to", "paper"): [20, 20]
+      }
+  inferenceBatchSize: 512
+  command: python -m examples.link_prediction.graph_store.heterogeneous_inference
+sharedConfig:
+  shouldSkipAutomaticTempAssetCleanup: false
+  shouldSkipInference: false
+  # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines.
+ shouldSkipModelEvaluation: true +featureFlags: + should_run_glt_backend: 'True' diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index b2d8517b..ca6afd28 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -22,3 +22,6 @@ tests: hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + het_dblp_sup_gs_test: + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 4b54246d2e9f4d599552f854dc9a5335256a6419 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 21 Jan 2026 21:52:37 +0000 Subject: [PATCH 03/12] fix name --- ..._sup_task_config.yaml => e2e_het_dblp_sup_gs_task_config.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/link_prediction/graph_store/configs/{e2e_het_dblp_sup_task_config.yaml => e2e_het_dblp_sup_gs_task_config.yaml} (100%) diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml similarity index 100% rename from examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_task_config.yaml rename to examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml From cef4da0cba575125327108d6d591e71226cacc1f Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 21 Jan 2026 23:45:01 +0000 Subject: [PATCH 04/12] hmmm --- .../graph_store/heterogeneous_inference.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index dddcdfd4..16175e33 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -57,8 +57,6 @@ def _inference_process( # from [0, num_processes). 
local_rank: int, local_world_size: int, - machine_rank: int, - machine_world_size: int, cluster_info: GraphStoreInfo, embedding_gcs_path: GcsUri, model_state_dict_uri: GcsUri, @@ -80,10 +78,7 @@ def _inference_process( Args: local_rank (int): Process number on the current machine local_world_size (int): Number of inference processes spawned by each machine - machine_rank (int): Machine number in the distributed setup - machine_world_size (int): Total number of machines in the distributed setup - master_ip_address (str): IP address of the master node in the distributed setup - master_default_process_group_port (int): Port on the master node in the distributed setup to setup Torch process group on + cluster_info (GraphStoreInfo): Cluster information embedding_gcs_path (GcsUri): GCS path to load embeddings from model_state_dict_uri (GcsUri): GCS path to load model from inference_batch_size (int): Batch size to use for inference @@ -128,8 +123,6 @@ def _inference_process( device = gigl.distributed.utils.get_available_device( local_process_rank=local_rank, ) # The device is automatically inferred based off the local process rank and the available devices - rank = machine_rank * local_world_size + local_rank - world_size = machine_world_size * local_world_size if torch.cuda.is_available(): torch.cuda.set_device( device @@ -140,6 +133,9 @@ def _inference_process( dataset = RemoteDistDataset( cluster_info, local_rank, mp_sharing_dict=mp_sharing_dict ) + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + machine_rank = cluster_info.compute_node_rank logger.info( f"Local rank {local_rank} in machine {machine_rank} has rank {rank}/{world_size} and using device {device} for inference" ) @@ -367,13 +363,6 @@ def _run_example_inference( f"Number of inference processes per machine ({num_inference_processes_per_machine}) must not be more than the number of GPUs: ({torch.cuda.device_count()})" ) - master_ip_address = gigl.distributed.utils.get_internal_ip_from_master_node() - machine_rank = torch.distributed.get_rank() - machine_world_size = torch.distributed.get_world_size() - master_default_process_group_port = ( - gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1)[0] - ) - ## Inference Start inference_start_time = time.time() @@ -413,9 +402,7 @@ def _run_example_inference( fn=_inference_process, args=( num_inference_processes_per_machine, # local_world_size - machine_rank, # machine_rank cluster_info, # cluster_info - master_default_process_group_port, # master_default_process_group_port embedding_output_gcs_folder, # embedding_gcs_path model_uri, # model_state_dict_uri inference_batch_size, # inference_batch_size @@ -432,11 +419,11 @@ def _run_example_inference( ) logger.info( - f"--- Inference finished on rank {machine_rank} for node type {inference_node_type}, which took {time.time()-inference_start_time:.2f} seconds" + f"--- Inference finished on rank {cluster_info.compute_node_rank} for node type {inference_node_type}, which took {time.time()-inference_start_time:.2f} seconds" ) # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. 
- if machine_rank == 0: + if cluster_info.compute_node_rank == 0: logger.info( f"--- Machine 0 triggers loading embeddings from GCS to BigQuery for node type {inference_node_type}" ) From 3e1709f03eb3d75447fb17aef1f4a46c1b07a57d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 21 Jan 2026 23:45:24 +0000 Subject: [PATCH 05/12] avoid scala builkds --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 952ab77e..dd7d72d8 100644 --- a/Makefile +++ b/Makefile @@ -291,7 +291,7 @@ _skip_build_deps: # job_name=... \ , and other params # compiled_pipeline_path="/tmp/gigl/my_pipeline.yaml" \ # run_dev_gnn_kubeflow_pipeline -run_dev_gnn_kubeflow_pipeline: $(if $(compiled_pipeline_path), _skip_build_deps, compile_jars push_new_docker_images) +run_dev_gnn_kubeflow_pipeline: $(if $(compiled_pipeline_path), _skip_build_deps, push_new_docker_images) uv run python -m gigl.orchestration.kubeflow.runner \ $(if $(compiled_pipeline_path),,--container_image_cuda=${DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG}) \ $(if $(compiled_pipeline_path),,--container_image_cpu=${DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG}) \ From b2e849fcf27c07ade4c7117860a31626f7808450 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 22 Jan 2026 17:25:01 +0000 Subject: [PATCH 06/12] idk --- .../graph_store/heterogeneous_inference.py | 8 ++++++-- python/gigl/distributed/distributed_neighborloader.py | 4 +++- .../graph_store/graph_store_integration_test.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index 16175e33..22f4ace2 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -99,7 +99,7 @@ def _inference_process( # In the case of the latter, the keys should be specified with format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE). # For the default example, we make a decision to keep the fanouts for all edge types the same, specifying the `fanout` with a `list[int]`. # To see an example of a 'fanout' with different behaviors per edge type, refer to `examples/link_prediction.configs/e2e_het_dblp_sup_task_config.yaml`. - + print(f"Rank {local_rank} doing inference for node type {inference_node_type}") fanout = inferencer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) @@ -141,6 +141,7 @@ def _inference_process( ) input_nodes = dataset.get_node_ids(inference_node_type) + sys.stdout.flush() data_loader = gigl.distributed.DistNeighborLoader( dataset=dataset, num_neighbors=num_neighbors, @@ -155,6 +156,8 @@ def _inference_process( # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, ) + print(f"Rank {local_rank} initialized the data loader for node type {inference_node_type}") + sys.stdout.flush() # Initialize a LinkPredictionGNN model and load parameters from # the saved model. model_state_dict = load_state_dict_from_uri( @@ -277,6 +280,7 @@ def _inference_process( f"--- All machines local rank {local_rank} finished inference for node type {inference_node_type}. Deleted data loader" ) + sys.stdout.flush() def _run_example_inference( job_name: str, @@ -396,7 +400,7 @@ def _run_example_inference( f"{num_files_at_gcs_path} files already detected at base gcs path {embedding_output_gcs_folder}. Cleaning up files at path ... 
" ) gcs_utils.delete_files_in_bucket_dir(embedding_output_gcs_folder) - + sys.stdout.flush() # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. mp.spawn( fn=_inference_process, diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 1125f1f9..664112b5 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -336,9 +336,11 @@ def __init__( device, worker_options, ) - print(f"node_rank {node_rank} initialized the dist loader") + logger.info(f"node_rank {node_rank} initialized the dist loader") torch.distributed.barrier() torch.distributed.barrier() + logger.info("All node ranks initialized the dist loader") + def _setup_for_graph_store( self, diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index ab127b7e..dfed4a62 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -255,7 +255,7 @@ def _get_expected_input_nodes_by_rank( class GraphStoreIntegrationTest(unittest.TestCase): - def test_graph_store_homogeneous(self): + def _test_graph_store_homogeneous(self): # Simulating two server machine, two compute machines. # Each machine has one process. cora_supervised_info = get_mocked_dataset_artifact_metadata()[ @@ -351,7 +351,7 @@ def test_graph_store_homogeneous(self): server_process.join() # TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build - @unittest.skip("Failing on Google Cloud Build - skiping for now") + #@unittest.skip("Failing on Google Cloud Build - skiping for now") def test_graph_store_heterogeneous(self): # Simulating two server machine, two compute machines. # Each machine has one process. 
From 443a899f63039da341a9b8c194bd5a4b74f78222 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 27 Jan 2026 19:37:41 +0000 Subject: [PATCH 07/12] maybe no loops --- .../graph_store/heterogeneous_inference.py | 368 +++++++++--------- 1 file changed, 179 insertions(+), 189 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index 22f4ace2..d9818bca 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -58,7 +58,6 @@ def _inference_process( local_rank: int, local_world_size: int, cluster_info: GraphStoreInfo, - embedding_gcs_path: GcsUri, model_state_dict_uri: GcsUri, inference_batch_size: int, hid_dim: int, @@ -68,6 +67,9 @@ def _inference_process( node_type_to_feature_dim: dict[NodeType, int], edge_type_to_feature_dim: dict[EdgeType, int], mp_sharing_dict: dict[str, torch.Tensor], + inference_node_types: list[NodeType], + gbml_config_pb_wrapper: GbmlConfigPbWrapper, + job_name: str, ): """ This function is spawned by multiple processes per machine and is responsible for: @@ -91,6 +93,9 @@ def _inference_process( node_type_to_feature_dim (dict[NodeType, int]): Input node feature dimension per node type for the model edge_type_to_feature_dim (dict[EdgeType, int]): Input edge feature dimension per edge type for the model mp_sharing_dict (dict[str, torch.Tensor]): Shared memory dictionary for sharing data between processes + inference_node_types (list[NodeType]): List of node types to generate embeddings for + gbml_config_pb_wrapper (GbmlConfigPbWrapper): GBML config wrapper + job_name (str): Name of current job """ # Parses the fanout as a string. @@ -139,148 +144,183 @@ def _inference_process( logger.info( f"Local rank {local_rank} in machine {machine_rank} has rank {rank}/{world_size} and using device {device} for inference" ) + for inference_node_type in inference_node_types: + logger.info(f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} starting inference for node type {inference_node_type}") + output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( + gbml_config_pb_wrapper, inference_node_type + ) - input_nodes = dataset.get_node_ids(inference_node_type) - sys.stdout.flush() - data_loader = gigl.distributed.DistNeighborLoader( - dataset=dataset, - num_neighbors=num_neighbors, - # We must pass in a tuple of (node_type, node_ids_on_current_process) for heterogeneous input - input_nodes=(inference_node_type, input_nodes), - num_workers=sampling_workers_per_inference_process, - batch_size=inference_batch_size, - pin_memory_device=device, - worker_concurrency=sampling_workers_per_inference_process, - channel_size=sampling_worker_shared_channel_size, - # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders - # don't compete for memory during initialization, causing OOM - process_start_gap_seconds=0, - ) - print(f"Rank {local_rank} initialized the data loader for node type {inference_node_type}") - sys.stdout.flush() - # Initialize a LinkPredictionGNN model and load parameters from - # the saved model. 
- model_state_dict = load_state_dict_from_uri( - load_from_uri=model_state_dict_uri, device=device - ) - model: LinkPredictionGNN = init_example_gigl_heterogeneous_model( - node_type_to_feature_dim=node_type_to_feature_dim, - edge_type_to_feature_dim=edge_type_to_feature_dim, - hid_dim=hid_dim, - out_dim=out_dim, - device=device, - state_dict=model_state_dict, - ) + bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( + bq_table_path=output_bq_table_path + ) - # Set the model to evaluation mode for inference. - model.eval() + # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. + # After inference has finished, we then load all embeddings to bigquery from GCS. + embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( + applied_task_identifier=AppliedTaskIdentifier(job_name), + bq_table_path=output_bq_table_path, + ) + input_nodes = dataset.get_node_ids(node_type=inference_node_type) + sys.stdout.flush() + data_loader = gigl.distributed.DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + # We must pass in a tuple of (node_type, node_ids_on_current_process) for heterogeneous input + input_nodes=(inference_node_type, input_nodes), + num_workers=sampling_workers_per_inference_process, + batch_size=inference_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_inference_process, + channel_size=sampling_worker_shared_channel_size, + # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders + # don't compete for memory during initialization, causing OOM + process_start_gap_seconds=0, + ) + print(f"Rank {local_rank} initialized the data loader for node type {inference_node_type}") + sys.stdout.flush() + # Initialize a LinkPredictionGNN model and load parameters from + # the saved model. + model_state_dict = load_state_dict_from_uri( + load_from_uri=model_state_dict_uri, device=device + ) + model: LinkPredictionGNN = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=node_type_to_feature_dim, + edge_type_to_feature_dim=edge_type_to_feature_dim, + hid_dim=hid_dim, + out_dim=out_dim, + device=device, + state_dict=model_state_dict, + ) - logger.info(f"Model initialized on device {device}") + # Set the model to evaluation mode for inference. + model.eval() - embedding_filename = f"machine_{machine_rank}_local_process_number_{local_rank}" + logger.info(f"Model initialized on device {device}") - # Get temporary GCS folder to write outputs of inference to. GiGL orchestration automatic cleans this, but - # if running manually, you will need to clean this directory so that retries don't end up with stale files. - gcs_utils = GcsUtils() - gcs_base_uri = GcsUri.join(embedding_gcs_path, embedding_filename) - num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) - if num_files_at_gcs_path > 0: - logger.warning( - f"{num_files_at_gcs_path} files already detected at base gcs path. Cleaning up files at path ... " - ) - gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) + embedding_filename = f"machine_{machine_rank}_local_process_number_{local_rank}" - # GiGL class for exporting embeddings to GCS. This is achieved by writing ids and embeddings to an in-memory buffer which gets - # flushed to GCS. 
Setting the min_shard_size_threshold_bytes field of this class sets the frequency of flushing to GCS, and defaults - # to only flushing when flush_records() is called explicitly or after exiting via a context manager. - exporter = EmbeddingExporter(export_dir=gcs_base_uri) + # Get temporary GCS folder to write outputs of inference to. GiGL orchestration automatic cleans this, but + # if running manually, you will need to clean this directory so that retries don't end up with stale files. + gcs_utils = GcsUtils() + gcs_base_uri = GcsUri.join(embedding_output_gcs_folder, embedding_filename) + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path. Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) - # We don't see logs for graph store mode for whatever reason. - # TOOD(#442): Revert this once the GCP issues are resolved. - sys.stdout.flush() - # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph - # sampling may fail. + # GiGL class for exporting embeddings to GCS. This is achieved by writing ids and embeddings to an in-memory buffer which gets + # flushed to GCS. Setting the min_shard_size_threshold_bytes field of this class sets the frequency of flushing to GCS, and defaults + # to only flushing when flush_records() is called explicitly or after exiting via a context manager. + exporter = EmbeddingExporter(export_dir=gcs_base_uri) - torch.distributed.barrier() + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. + sys.stdout.flush() + # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph + # sampling may fail. - t = time.time() - data_loading_start_time = time.time() - inference_start_time = time.time() - cumulative_data_loading_time = 0.0 - cumulative_inference_time = 0.0 + torch.distributed.barrier() - # Begin inference loop + t = time.time() + data_loading_start_time = time.time() + inference_start_time = time.time() + cumulative_data_loading_time = 0.0 + cumulative_inference_time = 0.0 - # Iterating through the dataloader yields a `torch_geometric.data.Data` type - for batch_idx, data in enumerate(data_loader): - cumulative_data_loading_time += time.time() - data_loading_start_time + # Begin inference loop - inference_start_time = time.time() + # Iterating through the dataloader yields a `torch_geometric.data.Data` type + for batch_idx, data in enumerate(data_loader): + cumulative_data_loading_time += time.time() - data_loading_start_time + + inference_start_time = time.time() + + # These arguments to forward are specific to the GiGL heterogeneous LinkPredictionGNN model. 
+ # If just using a nn.Module, you can just use output = model(data) + output = model( + data=data, output_node_types=[inference_node_type], device=device + )[inference_node_type] + + # The anchor node IDs are contained inside of the .batch field of the data + node_ids = data[inference_node_type].batch.cpu() + + # Only the first `batch_size` rows of the node embeddings contain the embeddings of the anchor nodes + node_embeddings = output[: data[inference_node_type].batch_size].cpu() - # These arguments to forward are specific to the GiGL heterogeneous LinkPredictionGNN model. - # If just using a nn.Module, you can just use output = model(data) - output = model( - data=data, output_node_types=[inference_node_type], device=device - )[inference_node_type] + # We add ids and embeddings to the in-memory buffer + exporter.add_embedding( + id_batch=node_ids, + embedding_batch=node_embeddings, + embedding_type=str(inference_node_type), + ) - # The anchor node IDs are contained inside of the .batch field of the data - node_ids = data[inference_node_type].batch.cpu() + cumulative_inference_time += time.time() - inference_start_time + + if batch_idx > 0 and batch_idx % log_every_n_batch == 0: + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. + sys.stdout.flush() + logger.info( + f"Rank {rank} processed {batch_idx} batches for node type {inference_node_type}. " + f"{log_every_n_batch} batches took {time.time() - t:.2f} seconds for node type {inference_node_type}. " + f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." + f"and model inference took {cumulative_inference_time:.2f} seconds." + ) + t = time.time() + cumulative_data_loading_time = 0 + cumulative_inference_time = 0 - # Only the first `batch_size` rows of the node embeddings contain the embeddings of the anchor nodes - node_embeddings = output[: data[inference_node_type].batch_size].cpu() + data_loading_start_time = time.time() - # We add ids and embeddings to the in-memory buffer - exporter.add_embedding( - id_batch=node_ids, - embedding_batch=node_embeddings, - embedding_type=str(inference_node_type), + logger.info( + f"--- Rank {rank} finished inference for node type {inference_node_type}." ) - cumulative_inference_time += time.time() - inference_start_time + write_embedding_start_time = time.time() + # Flushes all remaining embeddings to GCS + exporter.flush_records() - if batch_idx > 0 and batch_idx % log_every_n_batch == 0: - # We don't see logs for graph store mode for whatever reason. - # TOOD(#442): Revert this once the GCP issues are resolved. - sys.stdout.flush() + logger.info( + f"--- Rank {rank} finished writing embeddings to GCS for node type {inference_node_type}, which took {time.time()-write_embedding_start_time:.2f} seconds" + ) + # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. + if cluster_info.compute_node_rank == 0: logger.info( - f"Rank {rank} processed {batch_idx} batches for node type {inference_node_type}. " - f"{log_every_n_batch} batches took {time.time() - t:.2f} seconds for node type {inference_node_type}. " - f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." - f"and model inference took {cumulative_inference_time:.2f} seconds." 
+ f"--- Machine 0 triggers loading embeddings from GCS to BigQuery for node type {inference_node_type}" ) - t = time.time() - cumulative_data_loading_time = 0 - cumulative_inference_time = 0 - - data_loading_start_time = time.time() - - logger.info( - f"--- Rank {rank} finished inference for node type {inference_node_type}." - ) + # If we are on the last inference process, we should wait for this last write process to complete. Otherwise, we should + # load embeddings to bigquery in the background so that we are not blocking the start of the next inference process + should_run_async = local_rank != local_world_size - 1 - write_embedding_start_time = time.time() - # Flushes all remaining embeddings to GCS - exporter.flush_records() + # The `load_embeddings_to_bigquery` API returns a BigQuery LoadJob object + # representing the load operation, which allows user to monitor and retrieve + # details about the job status and result. + _ = load_embeddings_to_bigquery( + gcs_folder=embedding_output_gcs_folder, + project_id=bq_project_id, + dataset_id=bq_dataset_id, + table_id=bq_table_name, + should_run_async=should_run_async, + ) - logger.info( - f"--- Rank {rank} finished writing embeddings to GCS for node type {inference_node_type}, which took {time.time()-write_embedding_start_time:.2f} seconds" - ) - # We first call barrier to ensure that all machines and processes have finished inference. - # Only once all machines have finished inference is it safe to shutdown the data loader. - # Otherwise, processes which are still sampling *will* fail as the loaders they are trying to communicatate with will be shutdown. - # We then call `gc.collect()` to cleanup the memory used by the data_loader on the current machine. + # We first call barrier to ensure that all machines and processes have finished inference. + # Only once all machines have finished inference is it safe to shutdown the data loader. + # Otherwise, processes which are still sampling *will* fail as the loaders they are trying to communicatate with will be shutdown. + # We then call `gc.collect()` to cleanup the memory used by the data_loader on the current machine. - torch.distributed.barrier() + torch.distributed.barrier() - data_loader.shutdown() - gc.collect() + data_loader.shutdown() + gc.collect() - logger.info( - f"--- All machines local rank {local_rank} finished inference for node type {inference_node_type}. Deleted data loader" - ) + logger.info( + f"--- All machines local rank {local_rank} finished inference for node type {inference_node_type}. Deleted data loader" + ) - sys.stdout.flush() + sys.stdout.flush() def _run_example_inference( job_name: str, @@ -371,80 +411,30 @@ def _run_example_inference( inference_start_time = time.time() - for process_num, inference_node_type in enumerate(inference_node_types): - logger.info( - f"Starting inference process for node type {inference_node_type} ..." - ) - output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( - gbml_config_pb_wrapper, inference_node_type - ) - - bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( - bq_table_path=output_bq_table_path - ) - - # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. - # After inference has finished, we then load all embeddings to bigquery from GCS. 
- embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( - applied_task_identifier=AppliedTaskIdentifier(job_name), - bq_table_path=output_bq_table_path, - ) - mp_sharing_dict = mp.Manager().dict() - if cluster_info.compute_node_rank == 0: - gcs_utils = GcsUtils() - num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path( - embedding_output_gcs_folder - ) - if num_files_at_gcs_path > 0: - logger.warning( - f"{num_files_at_gcs_path} files already detected at base gcs path {embedding_output_gcs_folder}. Cleaning up files at path ... " - ) - gcs_utils.delete_files_in_bucket_dir(embedding_output_gcs_folder) - sys.stdout.flush() - # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. - mp.spawn( - fn=_inference_process, - args=( - num_inference_processes_per_machine, # local_world_size - cluster_info, # cluster_info - embedding_output_gcs_folder, # embedding_gcs_path - model_uri, # model_state_dict_uri - inference_batch_size, # inference_batch_size - hid_dim, # hid_dim - out_dim, # out_dim - inferencer_args, # inferencer_args - inference_node_type, # inference_node_type - node_type_to_feature_dim, # node_type_to_feature_dim - edge_type_to_feature_dim, # edge_type_to_feature_dim - mp_sharing_dict, # mp_sharing_dict - ), - nprocs=num_inference_processes_per_machine, - join=True, - ) - - logger.info( - f"--- Inference finished on rank {cluster_info.compute_node_rank} for node type {inference_node_type}, which took {time.time()-inference_start_time:.2f} seconds" - ) - - # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. - if cluster_info.compute_node_rank == 0: - logger.info( - f"--- Machine 0 triggers loading embeddings from GCS to BigQuery for node type {inference_node_type}" - ) - # If we are on the last inference process, we should wait for this last write process to complete. Otherwise, we should - # load embeddings to bigquery in the background so that we are not blocking the start of the next inference process - should_run_async = process_num != len(inference_node_types) - 1 - - # The `load_embeddings_to_bigquery` API returns a BigQuery LoadJob object - # representing the load operation, which allows user to monitor and retrieve - # details about the job status and result. - _ = load_embeddings_to_bigquery( - gcs_folder=embedding_output_gcs_folder, - project_id=bq_project_id, - dataset_id=bq_dataset_id, - table_id=bq_table_name, - should_run_async=should_run_async, - ) + mp_sharing_dict = mp.Manager().dict() + sys.stdout.flush() + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. 
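    # Aside, a self-contained toy sketch of that contract (names here are hypothetical):
    #
    #     import torch.multiprocessing as mp
    #
    #     def _toy_worker(local_rank: int, local_world_size: int) -> None:
    #         # `local_rank` is injected by mp.spawn as the process index in [0, nprocs);
    #         # everything passed via `args` follows it.
    #         print(f"process {local_rank} of {local_world_size}")
    #
    #     mp.spawn(fn=_toy_worker, args=(4,), nprocs=4, join=True)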
+ mp.spawn( + fn=_inference_process, + args=( + num_inference_processes_per_machine, # local_world_size + cluster_info, # cluster_info + model_uri, # model_state_dict_uri + inference_batch_size, # inference_batch_size + hid_dim, # hid_dim + out_dim, # out_dim + inferencer_args, # inferencer_args + inference_node_types, # inference_node_types + node_type_to_feature_dim, # node_type_to_feature_dim + edge_type_to_feature_dim, # edge_type_to_feature_dim + mp_sharing_dict, # mp_sharing_dict + inference_node_types, # inference_node_types + gbml_config_pb_wrapper, # gbml_config_pb_wrapper + job_name, # job_name + ), + nprocs=num_inference_processes_per_machine, + join=True, + ) logger.info( f"--- Program finished, which took {time.time()-program_start_time:.2f} seconds" From beab394265d404190d092ea21d2ad7153f484dc7 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 28 Jan 2026 00:40:54 +0000 Subject: [PATCH 08/12] idk --- .../graph_store/heterogeneous_inference.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index d9818bca..e7d644de 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -104,7 +104,8 @@ def _inference_process( # In the case of the latter, the keys should be specified with format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE). # For the default example, we make a decision to keep the fanouts for all edge types the same, specifying the `fanout` with a `list[int]`. # To see an example of a 'fanout' with different behaviors per edge type, refer to `examples/link_prediction.configs/e2e_het_dblp_sup_task_config.yaml`. - print(f"Rank {local_rank} doing inference for node type {inference_node_type}") + inference_node_types = sorted(inference_node_types) # Sort the inference node types to ensure consistent ordering across processes + print(f"Rank {local_rank} doing inference for node types {inference_node_types}") fanout = inferencer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) @@ -161,6 +162,7 @@ def _inference_process( bq_table_path=output_bq_table_path, ) input_nodes = dataset.get_node_ids(node_type=inference_node_type) + logger.info(f"Rank {local_rank} has {[n.shape for n in input_nodes.values()]} input nodes for node type {inference_node_type}") sys.stdout.flush() data_loader = gigl.distributed.DistNeighborLoader( dataset=dataset, @@ -176,8 +178,9 @@ def _inference_process( # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, ) - print(f"Rank {local_rank} initialized the data loader for node type {inference_node_type}") + print(f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} initialized the data loader for node type {inference_node_type}") sys.stdout.flush() + torch.distributed.barrier() # Initialize a LinkPredictionGNN model and load parameters from # the saved model. model_state_dict = load_state_dict_from_uri( @@ -258,7 +261,7 @@ def _inference_process( cumulative_inference_time += time.time() - inference_start_time - if batch_idx > 0 and batch_idx % log_every_n_batch == 0: + if batch_idx == 0 or (batch_idx > 0 and batch_idx % log_every_n_batch == 0): # We don't see logs for graph store mode for whatever reason. # TOOD(#442): Revert this once the GCP issues are resolved. 
sys.stdout.flush() From b15264b59448cd6e11797f9786fc38acc3a85b2d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 28 Jan 2026 22:10:43 +0000 Subject: [PATCH 09/12] idk --- .../link_prediction/graph_store/heterogeneous_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index e7d644de..1b3cd73b 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -380,9 +380,9 @@ def _run_example_inference( for condensed_edge_type, edge_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map.items() } - inference_node_types = sorted( + inference_node_types = [sorted( gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_task_root_node_types() - ) + )[0]] inferencer_args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) From 31aa3a72e8e1a536be6596902feb59174deb2cb9 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 29 Jan 2026 19:47:40 +0000 Subject: [PATCH 10/12] to dataclass thing --- .../graph_store/heterogeneous_inference.py | 616 ++++++++++-------- 1 file changed, 354 insertions(+), 262 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index 1b3cd73b..b3bb83a8 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -21,8 +21,9 @@ import argparse import gc -import sys import time +from dataclasses import dataclass +from typing import Optional, Union import torch import torch.distributed @@ -31,7 +32,35 @@ import gigl.distributed import gigl.distributed.utils -from gigl.common import GcsUri, UriFactory +from gigl.common import GcsUri, Uri, UriFactory +from gigl.common.data.export import EmbeddingExporter, load_embeddings_to_bigquery +from gigl.common.logger import Logger +from gigl.common.utils.gcs import GcsUtils +from gigl.distributed import DistDataset, build_dataset_from_task_config_uri +from gigl.nn import LinkPredictionGNN +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.common.utils.model import load_state_dict_from_uri +from gigl.src.inference.lib.assets import InferenceAssets +from gigl.utils.sampling import parse_fanout +import argparse +import gc +import os +import sys +import time +from collections.abc import MutableMapping +from dataclasses import dataclass +from typing import Union + +import torch +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_homogeneous_model + +import gigl.distributed +import gigl.distributed.utils +from gigl.common import GcsUri, Uri, UriFactory from gigl.common.data.export import EmbeddingExporter, load_embeddings_to_bigquery from gigl.common.logger import Logger from gigl.common.utils.gcs import GcsUtils @@ -48,28 +77,77 @@ from gigl.src.inference.lib.assets import InferenceAssets from gigl.utils.sampling import parse_fanout + logger = Logger() +@dataclass(frozen=True) +class InferenceProcessArgs: + """ + Arguments for the heterogeneous inference process. 
+ + Contains all configuration needed to run distributed inference for heterogeneous graph neural + networks, including distributed context, data configuration, model parameters, and inference + configuration. + + Attributes: + local_world_size (int): Number of inference processes spawned by each machine. + machine_rank (int): Rank of the current machine in the cluster. + machine_world_size (int): Total number of machines in the cluster. + master_ip_address (str): IP address of the master node for process group initialization. + master_default_process_group_port (int): Port for the default process group. + dataset (DistDataset): Loaded Distributed Dataset for inference. + inference_node_type (NodeType): Node type that embeddings should be generated for. + model_state_dict_uri (Uri): URI to load the trained model state dict from. + hid_dim (int): Hidden dimension of the model. + out_dim (int): Output dimension of the model. + node_type_to_feature_dim (dict[NodeType, int]): Mapping of node types to their feature + dimensions. + edge_type_to_feature_dim (dict[EdgeType, int]): Mapping of edge types to their feature + dimensions. + embedding_gcs_path (GcsUri): GCS path to write embeddings to. + inference_batch_size (int): Batch size to use for inference. + num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling, + where the ith item corresponds to the number of items to sample for the ith hop. + sampling_workers_per_inference_process (int): Number of sampling workers per inference + process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for + the channel during sampling (e.g., "4GB"). + log_every_n_batch (int): Frequency to log batch information during inference. + """ + + # Distributed context + local_world_size: int + machine_rank: int + machine_world_size: int + cluster_info: GraphStoreInfo + + # Data + inference_node_type: NodeType + mp_sharing_dict: MutableMapping[str, torch.Tensor] + + # Model + model_state_dict_uri: Uri + hid_dim: int + out_dim: int + node_type_to_feature_dim: dict[NodeType, int] + edge_type_to_feature_dim: dict[EdgeType, int] + + # Inference config + embedding_gcs_path: GcsUri + inference_batch_size: int + num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + sampling_workers_per_inference_process: int + sampling_worker_shared_channel_size: str + log_every_n_batch: int + + @torch.no_grad() def _inference_process( # When spawning processes, each process will be assigned a rank ranging # from [0, num_processes). 
local_rank: int, - local_world_size: int, - cluster_info: GraphStoreInfo, - model_state_dict_uri: GcsUri, - inference_batch_size: int, - hid_dim: int, - out_dim: int, - inferencer_args: dict[str, str], - inference_node_type: NodeType, - node_type_to_feature_dim: dict[NodeType, int], - edge_type_to_feature_dim: dict[EdgeType, int], - mp_sharing_dict: dict[str, torch.Tensor], - inference_node_types: list[NodeType], - gbml_config_pb_wrapper: GbmlConfigPbWrapper, - job_name: str, + args: InferenceProcessArgs, ): """ This function is spawned by multiple processes per machine and is responsible for: @@ -79,53 +157,9 @@ def _inference_process( Args: local_rank (int): Process number on the current machine - local_world_size (int): Number of inference processes spawned by each machine - cluster_info (GraphStoreInfo): Cluster information - embedding_gcs_path (GcsUri): GCS path to load embeddings from - model_state_dict_uri (GcsUri): GCS path to load model from - inference_batch_size (int): Batch size to use for inference - hid_dim (int): Hidden dimension of the model - out_dim (int): Output dimension of the model - dataset (DistDataset): Loaded Distributed Dataset for inference - inferencer_args (dict[str, str]): Additional arguments for inferencer - inference_node_type (NodeType): Node Type that embeddings should be generated for in current inference process. This is used to - tag the embeddings written to GCS. - node_type_to_feature_dim (dict[NodeType, int]): Input node feature dimension per node type for the model - edge_type_to_feature_dim (dict[EdgeType, int]): Input edge feature dimension per edge type for the model - mp_sharing_dict (dict[str, torch.Tensor]): Shared memory dictionary for sharing data between processes - inference_node_types (list[NodeType]): List of node types to generate embeddings for - gbml_config_pb_wrapper (GbmlConfigPbWrapper): GBML config wrapper - job_name (str): Name of current job + args (InferenceProcessArgs): Dataclass containing all inference process arguments """ - # Parses the fanout as a string. - # For the heterogeneous case, the fanouts can be specified as a string of a list of integers, such as "[10, 10]", which will apply this fanout - # to each edge type in the graph, or as string of format dict[(tuple[str, str, str])), list[int]] which will specify fanouts per edge type. - # In the case of the latter, the keys should be specified with format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE). - # For the default example, we make a decision to keep the fanouts for all edge types the same, specifying the `fanout` with a `list[int]`. - # To see an example of a 'fanout' with different behaviors per edge type, refer to `examples/link_prediction.configs/e2e_het_dblp_sup_task_config.yaml`. - inference_node_types = sorted(inference_node_types) # Sort the inference node types to ensure consistent ordering across processes - print(f"Rank {local_rank} doing inference for node types {inference_node_types}") - fanout = inferencer_args.get("num_neighbors", "[10, 10]") - num_neighbors = parse_fanout(fanout) - - # While the ideal value for `sampling_workers_per_inference_process` has been identified to be between `2` and `4`, this may need some tuning depending on the - # pipeline. We default this value to `4` here for simplicity. 
A `sampling_workers_per_process` which is too small may not have enough parallelization for - # sampling, which would slow down inference, while a value which is too large may slow down each sampling process due to competing resources, which would also - # then slow down inference. - sampling_workers_per_inference_process: int = int( - inferencer_args.get("sampling_workers_per_inference_process", "4") - ) - - # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and - # is the place to store pre-fetched data, so if it is too small then prefetching is limited, causing sampling slowdown. This parameter is a string - # with `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default this value to 4GB, - # but in production may need some tuning. - sampling_worker_shared_channel_size: str = inferencer_args.get( - "sampling_worker_shared_channel_size", "4GB" - ) - - log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) device = gigl.distributed.utils.get_available_device( local_process_rank=local_rank, ) # The device is automatically inferred based off the local process rank and the available devices @@ -133,197 +167,162 @@ def _inference_process( torch.cuda.set_device( device ) # Set the device for the current process. Without this, NCCL will fail when multiple GPUs are available. + + rank = args.machine_rank * args.local_world_size + local_rank + world_size = args.machine_world_size * args.local_world_size # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster. # If this is not done, the dataloader will not be able to sample from the graph store and will crash. - init_compute_process(local_rank, cluster_info) + init_compute_process(local_rank, args.cluster_info) dataset = RemoteDistDataset( - cluster_info, local_rank, mp_sharing_dict=mp_sharing_dict + args.cluster_info, local_rank, mp_sharing_dict=args.mp_sharing_dict ) - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - machine_rank = cluster_info.compute_node_rank logger.info( - f"Local rank {local_rank} in machine {machine_rank} has rank {rank}/{world_size} and using device {device} for inference" + f"Local rank {local_rank} in machine {args.machine_rank} has rank {rank}/{world_size} and using device {device} for inference" ) - for inference_node_type in inference_node_types: - logger.info(f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} starting inference for node type {inference_node_type}") - output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( - gbml_config_pb_wrapper, inference_node_type - ) - bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( - bq_table_path=output_bq_table_path - ) + # Get the node ids on the current machine for the current node type + input_nodes = dataset.get_node_ids(node_type=args.inference_node_type) + logger.info( + f"Rank {rank} got input nodes of shapes: {[f'{rank}: {node.shape}' for rank, node in input_nodes.items()]}" + ) + sys.stdout.flush() + data_loader = gigl.distributed.DistNeighborLoader( + dataset=dataset, + num_neighbors=args.num_neighbors, + # We must pass in a tuple of (node_type, node_ids_on_current_process) for heterogeneous input + input_nodes=(args.inference_node_type, input_nodes), + num_workers=args.sampling_workers_per_inference_process, + batch_size=args.inference_batch_size, + pin_memory_device=device, + 
worker_concurrency=args.sampling_workers_per_inference_process, + channel_size=args.sampling_worker_shared_channel_size, + # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders + # don't compete for memory during initialization, causing OOM + process_start_gap_seconds=0, + ) + sys.stdout.flush() + # Initialize a LinkPredictionGNN model and load parameters from + # the saved model. + model_state_dict = load_state_dict_from_uri( + load_from_uri=args.model_state_dict_uri, device=device + ) + model: LinkPredictionGNN = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=args.node_type_to_feature_dim, + edge_type_to_feature_dim=args.edge_type_to_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + state_dict=model_state_dict, + ) - # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. - # After inference has finished, we then load all embeddings to bigquery from GCS. - embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( - applied_task_identifier=AppliedTaskIdentifier(job_name), - bq_table_path=output_bq_table_path, - ) - input_nodes = dataset.get_node_ids(node_type=inference_node_type) - logger.info(f"Rank {local_rank} has {[n.shape for n in input_nodes.values()]} input nodes for node type {inference_node_type}") - sys.stdout.flush() - data_loader = gigl.distributed.DistNeighborLoader( - dataset=dataset, - num_neighbors=num_neighbors, - # We must pass in a tuple of (node_type, node_ids_on_current_process) for heterogeneous input - input_nodes=(inference_node_type, input_nodes), - num_workers=sampling_workers_per_inference_process, - batch_size=inference_batch_size, - pin_memory_device=device, - worker_concurrency=sampling_workers_per_inference_process, - channel_size=sampling_worker_shared_channel_size, - # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders - # don't compete for memory during initialization, causing OOM - process_start_gap_seconds=0, - ) - print(f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} initialized the data loader for node type {inference_node_type}") - sys.stdout.flush() - torch.distributed.barrier() - # Initialize a LinkPredictionGNN model and load parameters from - # the saved model. - model_state_dict = load_state_dict_from_uri( - load_from_uri=model_state_dict_uri, device=device - ) - model: LinkPredictionGNN = init_example_gigl_heterogeneous_model( - node_type_to_feature_dim=node_type_to_feature_dim, - edge_type_to_feature_dim=edge_type_to_feature_dim, - hid_dim=hid_dim, - out_dim=out_dim, - device=device, - state_dict=model_state_dict, - ) + # Set the model to evaluation mode for inference. + model.eval() - # Set the model to evaluation mode for inference. - model.eval() + logger.info(f"Model initialized on device {device}") + + embedding_filename = ( + f"machine_{args.machine_rank}_local_process_number_{local_rank}" + ) - logger.info(f"Model initialized on device {device}") + # Get temporary GCS folder to write outputs of inference to. GiGL orchestration automatic cleans this, but + # if running manually, you will need to clean this directory so that retries don't end up with stale files. 
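+ # If you do need to clear the directory by hand, a command along these lines works
+ # (the bucket and path below are placeholders, not a real output location):
+ #
+ #     gsutil -m rm -r gs://<your-temp-bucket>/<job-name>/embeddings/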
+ gcs_utils = GcsUtils() + gcs_base_uri = GcsUri.join(args.embedding_gcs_path, embedding_filename) + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path. Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) - embedding_filename = f"machine_{machine_rank}_local_process_number_{local_rank}" + # GiGL class for exporting embeddings to GCS. This is achieved by writing ids and embeddings to an in-memory buffer which gets + # flushed to GCS. Setting the min_shard_size_threshold_bytes field of this class sets the frequency of flushing to GCS, and defaults + # to only flushing when flush_records() is called explicitly or after exiting via a context manager. + exporter = EmbeddingExporter(export_dir=gcs_base_uri) - # Get temporary GCS folder to write outputs of inference to. GiGL orchestration automatic cleans this, but - # if running manually, you will need to clean this directory so that retries don't end up with stale files. - gcs_utils = GcsUtils() - gcs_base_uri = GcsUri.join(embedding_output_gcs_folder, embedding_filename) - num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) - if num_files_at_gcs_path > 0: - logger.warning( - f"{num_files_at_gcs_path} files already detected at base gcs path. Cleaning up files at path ... " - ) - gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) + # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph + # sampling may fail. + sys.stdout.flush() + torch.distributed.barrier() - # GiGL class for exporting embeddings to GCS. This is achieved by writing ids and embeddings to an in-memory buffer which gets - # flushed to GCS. Setting the min_shard_size_threshold_bytes field of this class sets the frequency of flushing to GCS, and defaults - # to only flushing when flush_records() is called explicitly or after exiting via a context manager. - exporter = EmbeddingExporter(export_dir=gcs_base_uri) + t = time.time() + data_loading_start_time = time.time() + inference_start_time = time.time() + cumulative_data_loading_time = 0.0 + cumulative_inference_time = 0.0 + sys.stdout.flush() - # We don't see logs for graph store mode for whatever reason. - # TOOD(#442): Revert this once the GCP issues are resolved. - sys.stdout.flush() - # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph - # sampling may fail. + # Begin inference loop - torch.distributed.barrier() + # Iterating through the dataloader yields a `torch_geometric.data.Data` type + for batch_idx, data in enumerate(data_loader): + cumulative_data_loading_time += time.time() - data_loading_start_time - t = time.time() - data_loading_start_time = time.time() inference_start_time = time.time() - cumulative_data_loading_time = 0.0 - cumulative_inference_time = 0.0 - - # Begin inference loop - # Iterating through the dataloader yields a `torch_geometric.data.Data` type - for batch_idx, data in enumerate(data_loader): - cumulative_data_loading_time += time.time() - data_loading_start_time + # These arguments to forward are specific to the GiGL heterogeneous LinkPredictionGNN model. 
+ # If just using a nn.Module, you can just use output = model(data) + output = model( + data=data, output_node_types=[args.inference_node_type], device=device + )[args.inference_node_type] - inference_start_time = time.time() + # The anchor node IDs are contained inside of the .batch field of the data + node_ids = data[args.inference_node_type].batch.cpu() - # These arguments to forward are specific to the GiGL heterogeneous LinkPredictionGNN model. - # If just using a nn.Module, you can just use output = model(data) - output = model( - data=data, output_node_types=[inference_node_type], device=device - )[inference_node_type] + # Only the first `batch_size` rows of the node embeddings contain the embeddings of the anchor nodes + node_embeddings = output[: data[args.inference_node_type].batch_size].cpu() - # The anchor node IDs are contained inside of the .batch field of the data - node_ids = data[inference_node_type].batch.cpu() + # We add ids and embeddings to the in-memory buffer + exporter.add_embedding( + id_batch=node_ids, + embedding_batch=node_embeddings, + embedding_type=str(args.inference_node_type), + ) - # Only the first `batch_size` rows of the node embeddings contain the embeddings of the anchor nodes - node_embeddings = output[: data[inference_node_type].batch_size].cpu() + cumulative_inference_time += time.time() - inference_start_time - # We add ids and embeddings to the in-memory buffer - exporter.add_embedding( - id_batch=node_ids, - embedding_batch=node_embeddings, - embedding_type=str(inference_node_type), + if batch_idx == 0 or (batch_idx > 0 and batch_idx % args.log_every_n_batch == 0): + logger.info( + f"Rank {rank} processed {batch_idx} batches for node type {args.inference_node_type}. " + f"{args.log_every_n_batch} batches took {time.time() - t:.2f} seconds for node type {args.inference_node_type}. " + f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." + f"and model inference took {cumulative_inference_time:.2f} seconds." ) + t = time.time() + cumulative_data_loading_time = 0 + cumulative_inference_time = 0 + sys.stdout.flush() - cumulative_inference_time += time.time() - inference_start_time - - if batch_idx == 0 or (batch_idx > 0 and batch_idx % log_every_n_batch == 0): - # We don't see logs for graph store mode for whatever reason. - # TOOD(#442): Revert this once the GCP issues are resolved. - sys.stdout.flush() - logger.info( - f"Rank {rank} processed {batch_idx} batches for node type {inference_node_type}. " - f"{log_every_n_batch} batches took {time.time() - t:.2f} seconds for node type {inference_node_type}. " - f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds." - f"and model inference took {cumulative_inference_time:.2f} seconds." - ) - t = time.time() - cumulative_data_loading_time = 0 - cumulative_inference_time = 0 - - data_loading_start_time = time.time() - - logger.info( - f"--- Rank {rank} finished inference for node type {inference_node_type}." - ) + data_loading_start_time = time.time() - write_embedding_start_time = time.time() - # Flushes all remaining embeddings to GCS - exporter.flush_records() + logger.info( + f"--- Rank {rank} finished inference for node type {args.inference_node_type}." + ) - logger.info( - f"--- Rank {rank} finished writing embeddings to GCS for node type {inference_node_type}, which took {time.time()-write_embedding_start_time:.2f} seconds" - ) - # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. 
- if cluster_info.compute_node_rank == 0:
- logger.info(
- f"--- Machine 0 triggers loading embeddings from GCS to BigQuery for node type {inference_node_type}"
- )
- # If we are on the last inference process, we should wait for this last write process to complete. Otherwise, we should
- # load embeddings to bigquery in the background so that we are not blocking the start of the next inference process
- should_run_async = local_rank != local_world_size - 1
+ write_embedding_start_time = time.time()
+ # Flushes all remaining embeddings to GCS
+ exporter.flush_records()
- # The `load_embeddings_to_bigquery` API returns a BigQuery LoadJob object
- # representing the load operation, which allows user to monitor and retrieve
- # details about the job status and result.
- _ = load_embeddings_to_bigquery(
- gcs_folder=embedding_output_gcs_folder,
- project_id=bq_project_id,
- dataset_id=bq_dataset_id,
- table_id=bq_table_name,
- should_run_async=should_run_async,
- )
+ logger.info(
+ f"--- Rank {rank} finished writing embeddings to GCS for node type {args.inference_node_type}, which took {time.time()-write_embedding_start_time:.2f} seconds"
+ )
+ # We first call barrier to ensure that all machines and processes have finished inference.
+ # Only once all machines have finished inference is it safe to shut down the data loader.
+ # Otherwise, processes which are still sampling *will* fail as the loaders they are trying to communicate with will be shut down.
+ # We then call `gc.collect()` to clean up the memory used by the data_loader on the current machine.
- # We first call barrier to ensure that all machines and processes have finished inference.
- # Only once all machines have finished inference is it safe to shutdown the data loader.
- # Otherwise, processes which are still sampling *will* fail as the loaders they are trying to communicatate with will be shutdown.
- # We then call `gc.collect()` to cleanup the memory used by the data_loader on the current machine.
+ torch.distributed.barrier()
- torch.distributed.barrier()
+ data_loader.shutdown()
+ gc.collect()
- data_loader.shutdown()
- gc.collect()
+ logger.info(
+ f"--- All machines local rank {local_rank} finished inference for node type {args.inference_node_type}. Deleted data loader"
+ )
- logger.info(
- f"--- All machines local rank {local_rank} finished inference for node type {inference_node_type}. 
Deleted data loader" - ) + sys.stdout.flush() - sys.stdout.flush() def _run_example_inference( job_name: str, @@ -351,9 +350,16 @@ def _run_example_inference( logger.info( f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" ) + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + cluster_info = get_graph_store_info() logger.info(f"Cluster info: {cluster_info}") torch.distributed.destroy_process_group() + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) # Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path, and additional inference args gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( @@ -380,9 +386,9 @@ def _run_example_inference( for condensed_edge_type, edge_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map.items() } - inference_node_types = [sorted( + inference_node_types = sorted( gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_task_root_node_types() - )[0]] + ) inferencer_args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) @@ -409,58 +415,144 @@ def _run_example_inference( raise ValueError( f"Number of inference processes per machine ({num_inference_processes_per_machine}) must not be more than the number of GPUs: ({torch.cuda.device_count()})" ) + sys.stdout.flush() ## Inference Start - + sys.stdout.flush() inference_start_time = time.time() - mp_sharing_dict = mp.Manager().dict() - sys.stdout.flush() - # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. - mp.spawn( - fn=_inference_process, - args=( - num_inference_processes_per_machine, # local_world_size - cluster_info, # cluster_info - model_uri, # model_state_dict_uri - inference_batch_size, # inference_batch_size - hid_dim, # hid_dim - out_dim, # out_dim - inferencer_args, # inferencer_args - inference_node_types, # inference_node_types - node_type_to_feature_dim, # node_type_to_feature_dim - edge_type_to_feature_dim, # edge_type_to_feature_dim - mp_sharing_dict, # mp_sharing_dict - inference_node_types, # inference_node_types - gbml_config_pb_wrapper, # gbml_config_pb_wrapper - job_name, # job_name - ), - nprocs=num_inference_processes_per_machine, - join=True, - ) + for process_num, inference_node_type in enumerate(inference_node_types): + logger.info( + f"Starting inference process for node type {inference_node_type} ..." + ) + output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( + gbml_config_pb_wrapper, inference_node_type + ) + + bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( + bq_table_path=output_bq_table_path + ) + + # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. + # After inference has finished, we then load all embeddings to bigquery from GCS. + embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( + applied_task_identifier=AppliedTaskIdentifier(job_name), + bq_table_path=output_bq_table_path, + ) + + # Parses the fanout as a string. 
For the heterogeneous case, the fanouts can be specified + # as a string of a list of integers, such as "[10, 10]", which will apply this fanout to + # each edge type in the graph, or as string of format dict[(tuple[str, str, str])), + # list[int]] which will specify fanouts per edge type. In the case of the latter, the keys + # should be specified with format (SRC_NODE_TYPE, RELATION, DST_NODE_TYPE). For the default + # example, we make a decision to keep the fanouts for all edge types the same, specifying + # the `fanout` with a `list[int]`. To see an example of a 'fanout' with different behaviors + # per edge type, refer to `examples/link_prediction.configs/e2e_het_dblp_sup_task_config.yaml`. + num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) + + # While the ideal value for `sampling_workers_per_inference_process` has been identified to + # be between `2` and `4`, this may need some tuning depending on the pipeline. We default + # this value to `4` here for simplicity. A `sampling_workers_per_process` which is too + # small may not have enough parallelization for sampling, which would slow down inference, + # while a value which is too large may slow down each sampling process due to competing + # resources, which would also then slow down inference. + sampling_workers_per_inference_process = int( + inferencer_args.get("sampling_workers_per_inference_process", "4") + ) + + # This value represents the shared-memory buffer size (bytes) allocated for the channel + # during sampling, and is the place to store pre-fetched data, so if it is too small then + # prefetching is limited, causing sampling slowdown. This parameter is a string with + # `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default + # this value to 4GB, but in production may need some tuning. + sampling_worker_shared_channel_size = inferencer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. 
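+ # Because `args=(inference_args,)` in the spawn call below, each spawned process is invoked as
+ # `_inference_process(local_rank, inference_args)`, and inside the worker the global rank is
+ # derived as `args.machine_rank * args.local_world_size + local_rank`. For example, local rank 1
+ # on machine 2 with 4 processes per machine (example values only) ends up with global rank
+ # 2 * 4 + 1 = 9.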
+ inference_args = InferenceProcessArgs( + local_world_size=num_inference_processes_per_machine, + machine_rank=cluster_info.compute_node_rank, + machine_world_size=cluster_info.num_compute_nodes, + cluster_info=cluster_info, + inference_node_type=inference_node_type, + mp_sharing_dict=torch.multiprocessing.Manager().dict(), + model_state_dict_uri=model_uri, + hid_dim=hid_dim, + out_dim=out_dim, + node_type_to_feature_dim=node_type_to_feature_dim, + edge_type_to_feature_dim=edge_type_to_feature_dim, + embedding_gcs_path=embedding_output_gcs_folder, + inference_batch_size=inference_batch_size, + num_neighbors=num_neighbors, + sampling_workers_per_inference_process=sampling_workers_per_inference_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + log_every_n_batch=log_every_n_batch, + ) + logger.info(f"Rank {cluster_info.compute_node_rank} started inference process for node type {inference_node_type} with {num_inference_processes_per_machine} processes\nargs: {inference_args}") + sys.stdout.flush() + + mp.spawn( + fn=_inference_process, + args=(inference_args,), + nprocs=num_inference_processes_per_machine, + join=True, + ) + + logger.info( + f"--- Inference finished on rank {cluster_info.compute_node_rank} for node type {inference_node_type}, which took {time.time()-inference_start_time:.2f} seconds" + ) + sys.stdout.flush() + + # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. + if cluster_info.compute_node_rank == 0: + logger.info( + f"--- Machine 0 triggers loading embeddings from GCS to BigQuery for node type {inference_node_type}" + ) + # If we are on the last inference process, we should wait for this last write process to complete. Otherwise, we should + # load embeddings to bigquery in the background so that we are not blocking the start of the next inference process + should_run_async = process_num != len(inference_node_types) - 1 + # The `load_embeddings_to_bigquery` API returns a BigQuery LoadJob object + # representing the load operation, which allows user to monitor and retrieve + # details about the job status and result. 
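+ # For instance, assuming the returned object exposes the standard
+ # `google.cloud.bigquery.job.LoadJob` interface, a caller that wants to block on and inspect
+ # the load could do (sketch only, not used in this example; same keyword arguments as below):
+ #
+ #     load_job = load_embeddings_to_bigquery(...)
+ #     load_job.result()  # waits for completion and raises if the load failed
+ #     logger.info(f"Loaded {load_job.output_rows} rows into BigQuery")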
+ _ = load_embeddings_to_bigquery( + gcs_folder=embedding_output_gcs_folder, + project_id=bq_project_id, + dataset_id=bq_dataset_id, + table_id=bq_table_name, + should_run_async=should_run_async, + ) + sys.stdout.flush() logger.info( f"--- Program finished, which took {time.time()-program_start_time:.2f} seconds" ) if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Arguments for distributed model inference on VertexAI" - ) - parser.add_argument( - "--job_name", - type=str, - help="Inference job name", - ) - parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + try: + parser = argparse.ArgumentParser( + description="Arguments for distributed model inference on VertexAI" + ) + parser.add_argument( + "--job_name", + type=str, + help="Inference job name", + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") - # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference - args, unused_args = parser.parse_known_args() - logger.info(f"Unused arguments: {unused_args}") + # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") - # We only need `job_name` and `task_config_uri` for running inference - _run_example_inference( - job_name=args.job_name, - task_config_uri=args.task_config_uri, - ) + # We only need `job_name` and `task_config_uri` for running inference + _run_example_inference( + job_name=args.job_name, + task_config_uri=args.task_config_uri, + ) + except Exception as e: + sys.stderr.write(f"Error: {e}\n") + sys.stderr.flush() + raise e From 5b3d699de9a30581642414e51f54be20528744ba Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 29 Jan 2026 22:44:08 +0000 Subject: [PATCH 11/12] mulitple rpc setups --- gigl/distributed/graph_store/storage_main.py | 40 +++++++++++-------- gigl/distributed/graph_store/storage_utils.py | 4 +- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/gigl/distributed/graph_store/storage_main.py b/gigl/distributed/graph_store/storage_main.py index 09a0d79c..36aca219 100644 --- a/gigl/distributed/graph_store/storage_main.py +++ b/gigl/distributed/graph_store/storage_main.py @@ -18,6 +18,7 @@ from gigl.distributed.utils import get_graph_store_info from gigl.distributed.utils.networking import get_free_ports_from_master_node from gigl.env.distributed import GraphStoreInfo +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper logger = Logger() @@ -102,27 +103,34 @@ def storage_node_process( is_inference=is_inference, _tfrecord_uri_pattern=tf_record_uri_pattern, ) + task_config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=task_config_uri + ) + inference_node_types = sorted(task_config.task_metadata_pb_wrapper.get_task_root_node_types()) + logger.info(f"Inference node types: {inference_node_types}") torch_process_port = get_free_ports_from_master_node(num_ports=1)[0] torch.distributed.destroy_process_group() server_processes = [] mp_context = torch.multiprocessing.get_context("spawn") # TODO(kmonte): Enable more than one server process per machine - for i in range(1): - server_process = mp_context.Process( - target=_run_storage_process, - args=( - storage_rank + i, # storage_rank - cluster_info, # cluster_info - dataset, # dataset - torch_process_port, # torch_process_port - storage_world_backend, # 
storage_world_backend - ), - ) - server_processes.append(server_process) - for server_process in server_processes: - server_process.start() - for server_process in server_processes: - server_process.join() + for i, inference_node_type in enumerate(inference_node_types): + logger.info(f"Starting storage node for inference node type {inference_node_type} (storage process group {i} / {len(inference_node_types)})") + for i in range(1): + server_process = mp_context.Process( + target=_run_storage_process, + args=( + storage_rank + i, # storage_rank + cluster_info, # cluster_info + dataset, # dataset + torch_process_port, # torch_process_port + storage_world_backend, # storage_world_backend + ), + ) + server_processes.append(server_process) + for server_process in server_processes: + server_process.start() + for server_process in server_processes: + server_process.join() if __name__ == "__main__": diff --git a/gigl/distributed/graph_store/storage_utils.py b/gigl/distributed/graph_store/storage_utils.py index e90209e7..074796c7 100644 --- a/gigl/distributed/graph_store/storage_utils.py +++ b/gigl/distributed/graph_store/storage_utils.py @@ -147,7 +147,9 @@ def get_node_ids_for_rank( raise ValueError( f"Node ids must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(_dataset.node_ids)}" ) - return shard_nodes_by_process(nodes, rank, world_size) + nodes = shard_nodes_by_process(nodes, rank, world_size) + logger.info(f"Got {nodes.shape[0]} nodes for rank {rank} / {world_size} with node type {node_type}") + return nodes def get_edge_types() -> Optional[list[EdgeType]]: From d442a67f083f506df3dc303d3f11602d9ea8f0e6 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 30 Jan 2026 19:01:09 +0000 Subject: [PATCH 12/12] works??? --- .../graph_store/heterogeneous_inference.py | 38 ++++++++++++------- gigl/distributed/graph_store/storage_main.py | 11 +++--- testing/e2e_tests/e2e_tests.yaml | 30 +++++++-------- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_inference.py b/examples/link_prediction/graph_store/heterogeneous_inference.py index b3bb83a8..8102bbd4 100644 --- a/examples/link_prediction/graph_store/heterogeneous_inference.py +++ b/examples/link_prediction/graph_store/heterogeneous_inference.py @@ -64,7 +64,7 @@ from gigl.common.data.export import EmbeddingExporter, load_embeddings_to_bigquery from gigl.common.logger import Logger from gigl.common.utils.gcs import GcsUtils -from gigl.distributed.graph_store.compute import init_compute_process +from gigl.distributed.graph_store.compute import init_compute_process, shutdown_compute_proccess from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo @@ -81,6 +81,12 @@ logger = Logger() +def flush(): + sys.stdout.write("\n") + sys.stdout.flush() + sys.stderr.write("\n") + sys.stderr.flush() + @dataclass(frozen=True) class InferenceProcessArgs: """ @@ -172,6 +178,8 @@ def _inference_process( world_size = args.machine_world_size * args.local_world_size # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster. # If this is not done, the dataloader will not be able to sample from the graph store and will crash. 
+ logger.info(f"Initializing compute process for rank {local_rank} in machine {args.machine_rank} with cluster info {args.cluster_info} for inference node type {args.inference_node_type}") + flush() init_compute_process(local_rank, args.cluster_info) dataset = RemoteDistDataset( args.cluster_info, local_rank, mp_sharing_dict=args.mp_sharing_dict @@ -185,7 +193,7 @@ def _inference_process( logger.info( f"Rank {rank} got input nodes of shapes: {[f'{rank}: {node.shape}' for rank, node in input_nodes.items()]}" ) - sys.stdout.flush() + flush() data_loader = gigl.distributed.DistNeighborLoader( dataset=dataset, num_neighbors=args.num_neighbors, @@ -200,7 +208,7 @@ def _inference_process( # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, ) - sys.stdout.flush() + flush() # Initialize a LinkPredictionGNN model and load parameters from # the saved model. model_state_dict = load_state_dict_from_uri( @@ -242,7 +250,7 @@ def _inference_process( # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph # sampling may fail. - sys.stdout.flush() + flush() torch.distributed.barrier() t = time.time() @@ -250,7 +258,7 @@ def _inference_process( inference_start_time = time.time() cumulative_data_loading_time = 0.0 cumulative_inference_time = 0.0 - sys.stdout.flush() + flush() # Begin inference loop @@ -291,7 +299,7 @@ def _inference_process( t = time.time() cumulative_data_loading_time = 0 cumulative_inference_time = 0 - sys.stdout.flush() + flush() data_loading_start_time = time.time() @@ -315,13 +323,14 @@ def _inference_process( torch.distributed.barrier() data_loader.shutdown() + shutdown_compute_proccess() gc.collect() logger.info( - f"--- All machines local rank {local_rank} finished inference for node type {args.inference_node_type}. Deleted data loader" + f"--- All machines local rank {local_rank} finished inference for node type {args.inference_node_type}. Deleted data loader and shutdown compute process" ) - sys.stdout.flush() + flush() def _run_example_inference( @@ -415,10 +424,10 @@ def _run_example_inference( raise ValueError( f"Number of inference processes per machine ({num_inference_processes_per_machine}) must not be more than the number of GPUs: ({torch.cuda.device_count()})" ) - sys.stdout.flush() + flush() ## Inference Start - sys.stdout.flush() + flush() inference_start_time = time.time() for process_num, inference_node_type in enumerate(inference_node_types): @@ -492,7 +501,7 @@ def _run_example_inference( log_every_n_batch=log_every_n_batch, ) logger.info(f"Rank {cluster_info.compute_node_rank} started inference process for node type {inference_node_type} with {num_inference_processes_per_machine} processes\nargs: {inference_args}") - sys.stdout.flush() + flush() mp.spawn( fn=_inference_process, @@ -504,7 +513,7 @@ def _run_example_inference( logger.info( f"--- Inference finished on rank {cluster_info.compute_node_rank} for node type {inference_node_type}, which took {time.time()-inference_start_time:.2f} seconds" ) - sys.stdout.flush() + flush() # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. 
if cluster_info.compute_node_rank == 0: @@ -525,7 +534,7 @@ def _run_example_inference( table_id=bq_table_name, should_run_async=should_run_async, ) - sys.stdout.flush() + flush() logger.info( f"--- Program finished, which took {time.time()-program_start_time:.2f} seconds" ) @@ -545,7 +554,8 @@ def _run_example_inference( # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference args, unused_args = parser.parse_known_args() - logger.info(f"Unused arguments: {unused_args}") + logger.info(f"Args: {args}, Unused arguments: {unused_args}") + flush() # We only need `job_name` and `task_config_uri` for running inference _run_example_inference( diff --git a/gigl/distributed/graph_store/storage_main.py b/gigl/distributed/graph_store/storage_main.py index 36aca219..fc160f2e 100644 --- a/gigl/distributed/graph_store/storage_main.py +++ b/gigl/distributed/graph_store/storage_main.py @@ -110,11 +110,11 @@ def storage_node_process( logger.info(f"Inference node types: {inference_node_types}") torch_process_port = get_free_ports_from_master_node(num_ports=1)[0] torch.distributed.destroy_process_group() - server_processes = [] mp_context = torch.multiprocessing.get_context("spawn") # TODO(kmonte): Enable more than one server process per machine for i, inference_node_type in enumerate(inference_node_types): logger.info(f"Starting storage node for inference node type {inference_node_type} (storage process group {i} / {len(inference_node_types)})") + server_processes = [] for i in range(1): server_process = mp_context.Process( target=_run_storage_process, @@ -127,10 +127,11 @@ def storage_node_process( ), ) server_processes.append(server_process) - for server_process in server_processes: - server_process.start() - for server_process in server_processes: - server_process.join() + for server_process in server_processes: + server_process.start() + for server_process in server_processes: + server_process.join() + logger.info(f"All server processes for inference node type {inference_node_type} joined") if __name__ == "__main__": diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index 0f3691e8..c1b9a7c7 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -1,21 +1,21 @@ # Combined e2e test configurations for GiGL # This file contains all the test specifications that can be run via the e2e test script tests: - cora_nalp_test: - task_config_uri: "gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_snc_test: - task_config_uri: "gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_udl_test: - task_config_uri: "gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - dblp_nalp_test: - task_config_uri: "gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - hom_cora_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" - 
resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # cora_nalp_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_snc_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_udl_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # dblp_nalp_test: + # task_config_uri: "gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # hom_cora_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" het_dblp_sup_test: task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}"