BaseModel

class models.BaseModel.BaseModel(*args: Any, **kwargs: Any)

Base class for all models.

__init__(config: utils.util.Config)
Parameters
  • config – the configuration object initialized by utils.manager.Manager.setup()

  • name – the name of the model

metrics

the metric dictionary containing metric_type: metric_value pairs

Type

dict

config
name

the name of the model

Type

str

index_dir

the folder to save index e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex

Type

str

collection_dir

the folder to save json collections returned by e.g. utils.index.AnseriniBM25Index.fit()

Type

str

encode_dir

the folder to save text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()

Type

str

query_dir

the folder to save query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()

Type

str

retrieval_result_path

the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()

Type

str

_rank

the current process ID

Type

int

_world_size

the number of all processes

Type

int

_distributed

whether distributed training/evaluation is enabled

Type

bool

logger

the logger

Type

MasterLoger

_move_to_device(data, exclude_keys=['text_idx', 'query_idx'])

Move data to device.

Parameters

exclude_keys – variables that should be kept unchanged

_l2_distance(x1: torch.Tensor, x2: torch.Tensor) torch.Tensor

Compute l2 similarity.

Parameters
  • x1 – tensor of [B, D]

  • x2 – tensor of [B, D]

_cos_sim(x1: torch.Tensor, x2: torch.Tensor, temperature: float = 0.1) torch.Tensor

Compute cosine similarity.

Parameters
  • x1 – tensor of [B, D]

  • x2 – tensor of [B, D]

  • temperature – scale the similarity scores by dividing them by temperature
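As a rough illustration of temperature-scaled cosine similarity, here is a minimal NumPy sketch. It is not the library's implementation: the name cos_sim_sketch is hypothetical, NumPy arrays stand in for torch tensors, and it assumes a pairwise [B, B] score matrix, which the reference above does not state.

```python
import numpy as np

def cos_sim_sketch(x1: np.ndarray, x2: np.ndarray, temperature: float = 0.1) -> np.ndarray:
    """Pairwise cosine similarity between rows of x1 and x2, divided by temperature."""
    # Normalize each row to unit length; the epsilon guards against zero vectors.
    x1n = x1 / np.maximum(np.linalg.norm(x1, axis=-1, keepdims=True), 1e-12)
    x2n = x2 / np.maximum(np.linalg.norm(x2, axis=-1, keepdims=True), 1e-12)
    # [B, D] @ [D, B] -> [B, B]; entry (i, j) compares x1[i] with x2[j].
    return (x1n @ x2n.T) / temperature

x = np.array([[1.0, 0.0], [0.0, 2.0]])
scores = cos_sim_sketch(x, x, temperature=0.1)
# Identical rows score 1 / 0.1 = 10, orthogonal rows score 0.
```

A smaller temperature widens the gap between scores, which sharpens the softmax distribution in a contrastive loss.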

_compute_teacher_score(x)

Compute teacher score in knowledge distillation; return None if training in contrastive mode.

_compute_loss(score: torch.Tensor, label: torch.Tensor, teacher_score: Optional[torch.Tensor] = None)

A general method to compute loss (contrastive cross-entropy or distillation)

Parameters
  • score – tensor of [B, *]

  • label – tensor of [B, *]

  • x – the input data
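The two loss modes named above can be sketched as follows. This is an illustrative NumPy version under stated assumptions, not the method's actual code: compute_loss_sketch is a hypothetical name, label is assumed to hold the index of the positive candidate in each row, and the distillation term is assumed to be the KL divergence from the teacher distribution to the student distribution.

```python
import numpy as np

def log_softmax(z: np.ndarray) -> np.ndarray:
    # Subtract the row maximum first for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def compute_loss_sketch(score, label, teacher_score=None) -> float:
    log_p = log_softmax(score)  # student log-probabilities, [B, C]
    if teacher_score is None:
        # Contrastive cross-entropy: negative log-probability of the positive.
        return float(-log_p[np.arange(len(label)), label].mean())
    # Distillation (assumed form): KL(teacher || student), averaged over the batch.
    log_q = log_softmax(teacher_score)
    return float((np.exp(log_q) * (log_q - log_p)).sum(axis=-1).mean())

score = np.array([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
label = np.array([0, 1])  # assumed: index of the positive in each row
ce = compute_loss_sketch(score, label)
kd = compute_loss_sketch(score, label, teacher_score=score)
```

When the teacher scores equal the student scores, the distillation term vanishes, which is a quick sanity check for this form of the loss.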

create_optimizer() torch.optim.Optimizer

Create the optimizer. Subclasses may override this function to create custom optimizers. Return None to use the default optimizer created by Trainer.

Returns

optimizer

_gather_objects(local_object: object) list[object]

Gather common python objects across processes.

Note

This function implicitly consumes GPU memory.

Parameters

local_object – python object to collect

_gather_tensors(local_tensor: torch.Tensor) torch.Tensor

Gather tensors from all GPUs on each process.

Parameters

local_tensor – the tensor that needs to be gathered

Returns

concatenation of local_tensor in each process

save_to_mmp(path: str, shape: tuple, dtype: numpy.dtype, loader: torch.utils.data.DataLoader, obj: numpy.ndarray, batch_size: int = 1000)
  1. Create a np.memmap file of shape with dtype;

  2. Create lock;

  3. Save the obj to the offset utils.util.Sequential_Sampler.start;

  4. Release lock.

Parameters
  • path – the memmap file path

  • shape – the shape of the memmap file to be created

  • dtype

  • loader – the dataloader for the data

  • obj – the array to be stored

  • batch_size – the batch size used when saving
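The four steps above can be sketched with np.memmap. This is a simplified stand-in rather than the actual method: save_to_mmp_sketch is a hypothetical name, the row offset is passed explicitly as start instead of being read from the loader's sampler, and a threading.Lock stands in for the cross-process lock a real multi-GPU run would need.

```python
import os
import tempfile
import threading
import numpy as np

_lock = threading.Lock()  # stand-in only; real code needs a cross-process lock

def save_to_mmp_sketch(path, shape, dtype, obj, start, batch_size=1000):
    with _lock:  # step 2: acquire the lock
        # Step 1: create the file on first use, reopen it read/write afterwards.
        mode = "r+" if os.path.exists(path) else "w+"
        mmp = np.memmap(path, dtype=dtype, mode=mode, shape=shape)
        # Step 3: copy obj in batches, beginning at row offset `start`.
        for i in range(0, len(obj), batch_size):
            chunk = obj[i:i + batch_size]
            mmp[start + i:start + i + len(chunk)] = chunk
        mmp.flush()
        del mmp
    # Step 4: the lock is released when the with-block exits.

tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "text_embeddings.mmp")
full = np.arange(12, dtype=np.float32).reshape(6, 2)
# Two "processes" each write their own shard into the shared file.
save_to_mmp_sketch(path, full.shape, np.float32, full[:3], start=0, batch_size=2)
save_to_mmp_sketch(path, full.shape, np.float32, full[3:], start=3, batch_size=2)
restored = np.array(np.memmap(path, dtype=np.float32, mode="r", shape=full.shape))
```

Because every writer targets a disjoint row range of the same file, the shards can be written in any order and still reassemble into the full array.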

gather_retrieval_result(retrieval_result: Union[dict[int, list[int]], dict[int, list[tuple[int, float]]]], hits: Optional[int] = None, retrieval_result_path: Optional[str] = None) Union[dict[int, list[int]], dict[int, list[tuple[int, float]]]]
  1. Gather retrieval_result across processes;

  2. Return the reordered result cut off to the top k;

  3. Create a lock;

  4. Save the result at models.BaseModel.BaseModel.retrieval_result_path.

  5. Release the lock.

Parameters

retrieval_result – each tuple is a document id-score pair

Returns

processed retrieval result
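The merge-and-cut-off part of the steps above might look like the sketch below. It is an assumption-laden illustration: merge_retrieval_results is a hypothetical helper, the list of per-process dictionaries stands in for what a real gather across processes would return, "reordered" is assumed to mean sorted by descending score, and the locking and file-saving steps are left out.

```python
from collections import defaultdict

def merge_retrieval_results(per_process_results, hits=None):
    """Merge {query_id: [(doc_id, score), ...]} dicts and keep the top `hits`."""
    merged = defaultdict(list)
    for result in per_process_results:
        for query_id, candidates in result.items():
            merged[query_id].extend(candidates)
    final = {}
    for query_id, candidates in merged.items():
        # Highest score first; sorted() is stable, so ties keep gather order.
        ranked = sorted(candidates, key=lambda pair: pair[1], reverse=True)
        final[query_id] = ranked if hits is None else ranked[:hits]
    return final

rank0 = {0: [(11, 0.9), (12, 0.2)], 1: [(13, 0.5)]}
rank1 = {0: [(21, 0.7)], 1: [(22, 0.8), (23, 0.1)]}
top2 = merge_retrieval_results([rank0, rank1], hits=2)
# top2 == {0: [(11, 0.9), (21, 0.7)], 1: [(22, 0.8), (13, 0.5)]}
```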

init_verifier(loaders: dict[str, torch.utils.data.DataLoader], load_all_verifier: bool = False)

Initialize the post verifier defined in utils.index.VERIFIER_MAP.

Parameters
  • loaders

  • load_all_verifier – if True, load all the verifier embeddings/codes

encode(loaders)

Shortcut for encoding both text and query.

index(loaders: dict[str, torch.utils.data.DataLoader])

The index method. Subclass should override this function.

retrieve(loaders: dict[str, torch.utils.data.DataLoader])

The retrieve method. Subclass should override this function.

log_result(**kwargs)

Save the model metrics and configurations in performance.log.

save(checkpoint: Optional[Union[str, int]] = None)

Save the model at checkpoint.

load()

Load the model with current config from config.load_ckpt.

step_end_callback(loaders, state)

Callback at the end of each training step.

BaseSparseModel

class models.BaseModel.BaseSparseModel(*args: Any, **kwargs: Any)

Base class for all models that rely on token weights to rank documents.

__init__(config: utils.util.Config)
Parameters
  • config – the configuration object initialized by utils.manager.Manager.setup()

  • name – the name of the model

metrics

the metric dictionary containing metric_type: metric_value pairs

Type

dict

config
name

the name of the model

Type

str

index_dir

the folder to save index e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex

Type

str

collection_dir

the folder to save json collections returned by e.g. utils.index.AnseriniBM25Index.fit()

Type

str

encode_dir

the folder to save text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()

Type

str

query_dir

the folder to save query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()

Type

str

retrieval_result_path

the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()

Type

str

_rank

the current process ID

Type

int

_world_size

the number of all processes

Type

int

_distributed

whether distributed training/evaluation is enabled

Type

bool

logger

the logger

Type

MasterLoger

_compute_overlap(query_token_id: torch.Tensor, text_token_id: torch.Tensor) torch.Tensor

Compute overlapping mask between the query tokens and positive sequence tokens across batches.

Parameters
  • query_token_id – [B1, LQ]

  • text_token_id – [B2, LS]

Returns

overlapping_mask – [B, LQ, B, LS] if cross_batch, else [B, LQ, LS]
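The two output shapes can be reproduced with plain broadcasting. The following NumPy sketch is illustrative only: compute_overlap_sketch is a hypothetical name, cross_batch is passed as an argument here although the real method may read it from the config, and any special handling of padding tokens is ignored.

```python
import numpy as np

def compute_overlap_sketch(query_token_id, text_token_id, cross_batch=False):
    if cross_batch:
        # Compare every query with every text: [B1, LQ, 1, 1] vs [1, 1, B2, LS].
        return query_token_id[:, :, None, None] == text_token_id[None, None, :, :]
    # Compare each query only with its paired text: [B, LQ, 1] vs [B, 1, LS].
    return query_token_id[:, :, None] == text_token_id[:, None, :]

q = np.array([[5, 7], [8, 9]])        # [B=2, LQ=2]
t = np.array([[7, 5, 1], [2, 9, 9]])  # [B=2, LS=3]
mask = compute_overlap_sketch(q, t)                          # shape (2, 2, 3)
cross_mask = compute_overlap_sketch(q, t, cross_batch=True)  # shape (2, 2, 2, 3)
```

In this sketch, mask[b, i, j] is True when query token i of pair b equals text token j of the same pair. The cross-batch mask adds a second batch axis so that each query is compared against every text in the batch.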

_gate_text(text_token_weights: numpy.ndarray, k: Optional[int] = None)

Gate the text token weights so that only the top config.query_gate_k tokens are valid. Keep the text_token_ids because we will use it to construct the entire inverted lists.

Parameters

query_embeddings – [N, L, 1]
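Top-k gating of token weights can be sketched as below. This is a hedged NumPy illustration, not the method itself: gate_topk_sketch is a hypothetical name, the weights are assumed to have shape [N, L, 1], and gated-out tokens are assumed to be set to zero.

```python
import numpy as np

def gate_topk_sketch(token_weights: np.ndarray, k: int) -> np.ndarray:
    w = token_weights.squeeze(-1)  # [N, L]
    # Indices of the k largest weights in each row (order within the k is arbitrary).
    topk_idx = np.argpartition(-w, k - 1, axis=-1)[:, :k]
    keep = np.zeros_like(w, dtype=bool)
    np.put_along_axis(keep, topk_idx, True, axis=-1)
    # Zero every weight outside the top k and restore the trailing axis.
    return np.where(keep, w, 0.0)[..., None]

weights = np.array([[0.1, 0.9, 0.5], [0.7, 0.2, 0.3]])[..., None]  # [N=2, L=3, 1]
gated = gate_topk_sketch(weights, k=2)
# Row 0 keeps 0.9 and 0.5; row 1 keeps 0.7 and 0.3.
```

Only the weights change in this sketch. The token ids are untouched, matching the note above that they are kept for building the inverted lists.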

encode_text_step(x)

One step in encode_text.

Parameters

x – a data record.

Returns

  • the text token id for indexing, array of [B, L]

  • the text token embedding for indexing, array of [B, L, D]

encode_query_step(x)

One step in encode_query.

Parameters

x – a data record.

Returns

  • the query token id for searching, array of [B, L]

  • the query token embedding for indexing, array of [B, L, D]

inverted_index(loader_text: torch.utils.data.DataLoader)

Construct utils.index.BaseInvertedIndex.

anserini_index(loader_text: torch.utils.data.DataLoader)

Construct utils.index.BaseAnseriniIndex.

index(**kwargs)

The index method. Subclass should override this function.

retrieve(**kwargs)

The retrieve method. Subclass should override this function.

generate_code(loaders: dict[str, torch.utils.data.DataLoader])

Generate codes from the cache embedding files.

BaseDenseModel

class models.BaseModel.BaseDenseModel(*args: Any, **kwargs: Any)

Base class for all models that rely on sequence embeddings to rank documents.

__init__(config)
Parameters
  • config – the configuration object initialized by utils.manager.Manager.setup()

  • name – the name of the model

metrics

the metric dictionary containing metric_type: metric_value pairs

Type

dict

config
name

the name of the model

Type

str

index_dir

the folder to save index e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex

Type

str

collection_dir

the folder to save json collections returned by e.g. utils.index.AnseriniBM25Index.fit()

Type

str

encode_dir

the folder to save text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()

Type

str

query_dir

the folder to save query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()

Type

str

retrieval_result_path

the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()

Type

str

_rank

the current process ID

Type

int

_world_size

the number of all processes

Type

int

_distributed

whether distributed training/evaluation is enabled

Type

bool

logger

the logger

Type

MasterLoger

faiss_index(loader_text: torch.utils.data.DataLoader)

Construct utils.index.FaissIndex

index(**kwargs)

The index method. Subclass should override this function.

retrieve(**kwargs)

The retrieve method. Subclass should override this function.

cluster(loaders: dict[str, torch.utils.data.DataLoader])

Perform clustering over cached embeddings.

generate_code(loaders: dict[str, torch.utils.data.DataLoader])

Generate codes from the cached clustering assignments.

BaseGenerativeModel

class models.BaseModel.BaseGenerativeModel(*args: Any, **kwargs: Any)

Base class for generative models e.g. DSI, WebUltron.

__init__(config: utils.util.Config)
Parameters
  • config – the configuration object initialized by utils.manager.Manager.setup()

  • name – the name of the model

metrics

the metric dictionary containing metric_type: metric_value pairs

Type

dict

config
name

the name of the model

Type

str

index_dir

the folder to save index e.g. utils.index.FaissIndex and utils.index.InvertedVectorIndex

Type

str

collection_dir

the folder to save json collections returned by e.g. utils.index.AnseriniBM25Index.fit()

Type

str

encode_dir

the folder to save text encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_text()

Type

str

query_dir

the folder to save query encoding memmap file returned by models.BaseModel.BaseSparseModel.encode_query()

Type

str

retrieval_result_path

the path of the final retrieval result file returned by models.BaseModel.BaseModel.retrieve()

Type

str

_rank

the current process ID

Type

int

_world_size

the number of all processes

Type

int

_distributed

whether distributed training/evaluation is enabled

Type

bool

logger

the logger

Type

MasterLoger

code_dir

the save folder, kept separate for generative models

Type

str

generative_index(loader_text: torch.utils.data.DataLoader)

Construct utils.index.TrieIndex.
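For readers unfamiliar with trie indexes in generative retrieval, the sketch below shows the general idea: document codes are stored as token paths, and during decoding the trie is asked which tokens may follow a given prefix. TrieSketch and its methods are hypothetical and are not the interface of utils.index.TrieIndex.

```python
class TrieSketch:
    """A toy prefix tree over document codes (sequences of token ids)."""

    def __init__(self):
        self.root = {}

    def add(self, code):
        # Walk down the tree, creating a child node for each unseen token.
        node = self.root
        for token in code:
            node = node.setdefault(token, {})

    def next_tokens(self, prefix):
        # Return the tokens that extend `prefix` toward at least one stored code.
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return sorted(node)

trie = TrieSketch()
for code in [(1, 4, 2), (1, 4, 3), (1, 5, 0), (2, 0, 0)]:
    trie.add(code)
allowed = trie.next_tokens((1, 4))
# allowed == [2, 3]
```

Restricting each decoding step to next_tokens(prefix) guarantees that the generated sequence always corresponds to a stored document code.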

index(**kwargs)

The index method. Subclass should override this function.

retrieve(**kwargs)

The retrieve method. Subclass should override this function.