#!/usr/bin/env python
import logging
import pickle
from typing import List, Tuple

import nibabel as nib
import numpy as np
import scipy.sparse as sp
import talon
import talon.cli
# One-line description shown by argparse for the `filter` sub-command.
DESCRIPTION = (
    "Use TALON to filter a tractogram with the Volume Fraction forward model."
)
def add_parser(subparsers):
    """Register the `filter` sub-command and return its argument parser.

    Args:
        subparsers: object exposing `add_parser(name, ...)`; presumably the
            value returned by `ArgumentParser.add_subparsers()` of the main
            TALON command line interface.

    Returns:
        The parser created for the `filter` sub-command. Its default `func`
        is set to `run`, so the dispatcher can call `args.func(...)`.
    """
    parser = subparsers.add_parser(
        'filter',
        description=DESCRIPTION,
        help='Filter a tractogram using TALON.')

    # --- positional arguments -------------------------------------------
    parser.add_argument(
        'in_tracks',
        type=str,
        help='Input tractogram file in RAS+ and mm space. '
             'The streamline coordinate (0,0,0) refers to the center of the '
             'voxel. '
             'Must be in NiBabel-readable format (.trk or .tck).'
    )

    parser.add_argument(
        'in_data',
        type=str,
        help='Input data to be fitted. Serves also as reference space for '
             'tractogram. '
             'Must be in NiBabel-readable format (.nii or .nii.gz).'
    )

    parser.add_argument(
        'out_weights',
        type=str,
        help='Output text file containing the streamline weights.'
    )

    # --- operator selection ---------------------------------------------
    parser.add_argument(
        '--operator-type',
        type=str,
        choices=['reference', 'fast', 'opencl'],
        default='fast',
        help='Type of operator to use. Default: `fast`.'
    )

    # Shared helper; presumably adds the number-of-directions option that
    # ends up as the `ndir` argument of `run`.
    talon.cli.utils.add_ndir_to_input(parser)

    # --- solver options -------------------------------------------------
    solver_options = parser.add_argument_group('Solver options')
    solver_options.add_argument(
        '--allow-negative-x',
        action='store_true',
        help='Disables the non negativity constraint.'
    )

    solver_options.add_argument(
        '--sigma',
        type=float,
        default=0.0,
        metavar='value',
        help='Sets the regularization scale parameter as in (Frigo, 2021). '
             'The final value of lambda is `sigma*max(||At*data||/gwei)`, '
             'where sigma is the passed parameter, `||At*data||` is the 2-norm '
             'of the product between the transposed linear operator and the '
             'data, and `gwei` is the vector of the weights associated to each '
             'group of streamlines. Default: %(default)s.'
    )

    solver_options.add_argument(
        '--streamline-assignment',
        type=str,
        metavar='file',
        help='Activates the group sparsity regularization by specifying the '
             'node assignments of each streamline to some parcellation. '
             'Typically, this file is produced by the Mrtrix3 command '
             '`tck2connectome` with the option `-out_assignment`. '
             'The file is expected to be in text format with one row per '
             'streamline. E.g., if the first row is [5, 14], the first '
             'streamline will be bundled together with all the streamlines '
             'whose corresponding rows are [5, 14] or [14, 5].'
    )

    solver_options.add_argument(
        '--connectome',
        type=str,
        metavar='file',
        help='Activates the FIT regularization by specifying the connectivity '
             'matrix. '
             'Each streamline bundle is associated to the entry in the '
             'connectivity matrix corresponding to the region labels that it '
             'connects. '
             'E.g., the bundle connecting regions 5 and 14 is associated to '
             'the entry [5, 14] of the connectivity matrix. '
             'Notice that the first row and column correspond to the zero '
             'label. '
             'Must be used together with `--streamline-assignment`.'
    )

    solver_options.add_argument(
        '--objective-relative-tolerance',
        type=float,
        default=1e-6,
        metavar='value',
        help='Sets relative tolerance on cost function. Default: %(default)s.'
    )

    solver_options.add_argument(
        '--x-absolute-tolerance',
        type=float,
        default=1e-6,
        metavar='value',
        help='Sets absolute tolerance on variable. Default: %(default)s.'
    )

    solver_options.add_argument(
        '--maxiter',
        type=int,
        default=1000,
        metavar='count',
        help='Sets maximum number of iterations. Default: %(default)s.'
    )

    # --- loading / saving the linear operator ---------------------------
    parser.add_argument(
        '--precomputed-indices-weights',
        type=str,
        nargs=2,
        metavar=('file_idx', 'file_wei'),
        help='Uses the indices and weights passed as input to build the '
             'linear operator. '
             'E.g. `--precomputed-indices-weights <indices>.npz '
             '<weights>.npz`. '
             'The two matrices must be defined on the same number of '
             'directions as the ones that are used at the call of this script.'
    )

    # The two save formats cannot be requested at the same time.
    operator_format = parser.add_mutually_exclusive_group(required=False)
    operator_format.add_argument(
        '--save-generators-indices-weights',
        type=str,
        nargs=3,
        metavar=('file_gen', 'file_idx', 'file_wei'),
        help='Saves the linear operator as three separate files '
             '`<generators>.npy <indices>.npz <weights>.npz`. '
             'All types of operator can be saved in this format.'
    )

    operator_format.add_argument(
        '--save-operator-pickle',
        type=str,
        metavar='file',
        help='Saves the linear operator with pickle. Only available when '
             '--operator-type is set to `reference` or `fast`.'
    )

    # Shared helper; presumably adds the verbosity and `--force` options.
    talon.cli.utils.add_verbosity_and_force_to_parser(parser)

    # Entry point invoked by the command dispatcher for this sub-command.
    parser.set_defaults(func=run)

    return parser
def load_operator_and_data(indices: sp.coo_matrix, weights: sp.coo_matrix,
                           fdata: str, operator_type: str, ndir: int
                           ) -> Tuple[talon.core.LinearOperator, np.ndarray]:
    """Load the data volume, mask it, and build the linear operator.

    Args:
        indices: sparse matrix of generator indices passed to
            `talon.operator`.
        weights: sparse matrix of weights passed to `talon.operator`. Its
            number of rows must equal the number of voxels in the data.
        fdata: path of the NiBabel-readable data volume.
        operator_type: type of operator to build; the 'opencl' type gets a
            dedicated generator and data layout, any other value is passed
            to `talon.operator` with single-column generators.
        ndir: number of directions, i.e. number of rows of the generators.

    Returns:
        The pair `(operator, data)`, where `data` is the flattened, masked
        data vector (expanded for the 'opencl' operator).

    Raises:
        ValueError: if the number of voxels in the data does not match the
            number of rows of `weights`.
    """
    data = nib.load(fdata).get_fdata().ravel()
    logging.info('Loaded data')

    # Fail before building any operator if the shapes cannot match.
    if data.size != weights.shape[0]:
        raise ValueError('Data volume and weights matrix are not compatible.')

    # Mask data: a throwaway 'reference' operator is built only so that
    # `talon.utils.mask_data` can use it; it is released right after.
    mask_generators = np.ones((ndir, 1), dtype=np.float64)
    mask_op = talon.operator(mask_generators, indices, weights,
                             operator_type='reference')
    data = talon.utils.mask_data(data, mask_op)
    logging.info('Preprocessed data')
    del mask_op
    del mask_generators

    # Define the operator that is actually returned.
    if operator_type == 'opencl':
        # Four-column generators with only the first column set to one.
        generators = np.zeros((ndir, 4), dtype=talon.core.DATATYPE)
        generators[:, 0] = 1.0
        # Each data value is followed by three zeros, so the data vector has
        # four entries per original sample (presumably to match the
        # four-column generators used by the OpenCL operator).
        data = np.kron(data[:, None], np.array([[1., 0, 0, 0]]).T).squeeze()
    else:
        generators = np.ones((ndir, 1), dtype=talon.core.DATATYPE)

    logging.info(f'Loading operator of type: {operator_type}')
    op = talon.operator(generators, indices, weights,
                        operator_type=operator_type)
    logging.info('Loaded operator')
    return op, data
def compute_indices_and_generators(in_tracks: str, in_data: str, ndir: int):
    """Voxelize a tractogram on the grid of a reference image.

    Args:
        in_tracks: path of the tractogram, loaded with
            `nib.streamlines.load`.
        in_data: path of the reference image, loaded with `nib.load`; its
            affine and shape define the voxel grid.
        ndir: number of points requested from `talon.utils.directions`.

    Returns:
        The result of `talon.voxelize`; the caller in this module unpacks it
        as an `(indices, weights)` pair.
    """
    logging.info('Voxelizing tractogram')

    reference = nib.load(in_data)
    # Inverse of the image affine: maps world coordinates to voxel space.
    world_to_voxel = np.linalg.inv(reference.affine)
    grid_shape = reference.shape

    directions = talon.utils.directions(number_of_points=ndir)

    tractogram_file = nib.streamlines.load(in_tracks)
    logging.info('Changing streamlines coordinates to voxel space')
    tractogram_file.tractogram.apply_affine(world_to_voxel)

    return talon.voxelize(tractogram_file.streamlines, directions, grid_shape)
def run(in_tracks: str, in_data: str, out_weights: str, force: bool, ndir: int,
        precomputed_indices_weights: List[str],
        save_generators_indices_weights: List[str],
        save_operator_pickle: str, operator_type: str,
        streamline_assignment: str, connectome: str, sigma: float,
        allow_negative_x: bool, maxiter: int,
        objective_relative_tolerance: float, x_absolute_tolerance: float,
        **kwargs):
    """
    Filter a tractogram with TALON and save the streamline weights.

    The linear operator is built either from precomputed indices and weights
    or by voxelizing the tractogram; the problem is then solved in batches
    of at most 25 iterations until the solver reports success or `maxiter`
    iterations have been spent.

    Args:
        in_tracks: str
            Input tractogram file in RAS+ and mm space. The streamline
            coordinate (0,0,0) refers to the center of the voxel. Must be in
            NiBabel-readable format (.trk or .tck).
        in_data: str
            Input data to be fitted. Serves also as reference space for
            tractogram. Must be in NiBabel-readable format (.nii or .nii.gz).
        out_weights: str
            Output text file containing the streamline weights.
        force: bool
            True if the file can be overwritten, False otherwise.
        ndir: int
            Number of directions for the voxelization.
        precomputed_indices_weights: List
            Uses the indices and weights passed as input to build the linear
            operator. The two matrices must be defined on the same number of
            directions (ndir) as the ones that are used at the call of this
            script. None to voxelize the tractogram instead.
        save_generators_indices_weights: List
            Saves the linear operator as three separate files. None to skip.
        save_operator_pickle: str
            Saves the linear operator with pickle. Only available when
            `operator_type` is 'fast' or 'reference'. None to skip.
        operator_type: str
            Type of operator to use. Choices: 'reference', 'fast', 'opencl'.
        streamline_assignment: str
            Path to the file whose rows contain the assignment of each
            streamline. E.g., if the n-th row is '5 17', the n-th streamline is
            assigned to regions 5 and 17. The region labels must be integer
            values and separated by a blank space. Lines starting with # are
            skipped. None to disable the group regularization.
        connectome: str
            Path to the connectivity matrix to be employed in txt format. The
            first row and column correspond to the zero label. Only used
            together with `streamline_assignment`.
        sigma: float
            Sets the regularization scale parameter as in (Frigo, 2021). The
            final value of lambda is `sigma*max(||At*data||/gwei)`, where sigma
            is the passed parameter, `||At*data||` is the 2-norm of the product
            between the transposed linear operator and the data, and `gwei` is
            the vector of the weights associated to each group of streamlines.
        allow_negative_x: bool
            Disables the non negativity constraint.
        maxiter: int
            Sets maximum number of iterations.
        objective_relative_tolerance: float
            Sets relative tolerance on cost function.
        x_absolute_tolerance: float
            Sets absolute tolerance on variable.
        **kwargs:
            Ignored; presumably absorbs the remaining entries of the parsed
            command line namespace.

    Raises:
        ValueError: if the operator is requested in pickle format for an
            operator type other than 'reference' or 'fast'.
    """
    # Fail fast: this incompatibility used to be detected only after the
    # whole optimization had run. The pickle file is only written when the
    # three-file format is not requested, so only that case is rejected.
    if (save_generators_indices_weights is None and save_operator_pickle
            and operator_type not in ['reference', 'fast']):
        raise ValueError('Operator type does not allow to save in pickle '
                         'format.')

    # Make sure every output can be written before doing any heavy work.
    files_to_check = [out_weights]
    if save_generators_indices_weights is not None:
        files_to_check.extend(save_generators_indices_weights)
    if save_operator_pickle is not None:
        files_to_check.append(save_operator_pickle)

    for f in files_to_check:
        talon.cli.utils.check_can_write_file(f, force=force)

    # Get the indices and weights of the operator.
    if precomputed_indices_weights is not None:
        i = sp.load_npz(precomputed_indices_weights[0])
        w = sp.load_npz(precomputed_indices_weights[1])
    else:  # Voxelize
        i, w = compute_indices_and_generators(in_tracks, in_data, ndir=ndir)

    # Go back to filter
    op, data = load_operator_and_data(i, w, in_data, operator_type, ndir)

    # Regularization: disabled unless a streamline assignment is given.
    llambda = 0.0
    groups = None
    weights = None
    if streamline_assignment is not None:
        mapping = talon.cli.utils.assignment_to_mapping(streamline_assignment)

        c = None
        if connectome is not None:
            c = np.loadtxt(connectome)

        groups, weights = talon.cli.utils.mapping_to_groups_weights(mapping, c)
        logging.info(f'Number of streamline bundles: {len(groups)}')

        # lambda = sigma * ||At*data||_2 * max(1 / weights)
        llambda = np.linalg.norm(op.T @ data, 2)
        llambda *= np.max(1 / weights)
        llambda *= sigma
        logging.info(f'Regularization parameter: {llambda}')
    elif connectome is not None:
        # The connectome is only read inside the branch above.
        logging.warning('The connectome is ignored because no streamline '
                        'assignment was given.')

    regterm = talon.regularization(
        non_negativity=np.logical_not(allow_negative_x),
        regularization_parameter=llambda,
        groups=groups,
        weights=weights
    )
    logging.info(f'Non-negativity: {regterm.non_negativity}')

    # Solve in batches of at most 25 iterations so that progress is logged,
    # restarting each batch from the previous estimate. The loop is bounded
    # by `maxiter`: without that bound the per-batch budget would drop to
    # zero once `maxiter` is reached and the loop could never terminate.
    x0 = np.zeros(op.shape[1], dtype=talon.core.DATATYPE)
    converged = False
    total_nit = 0
    while not converged and total_nit < maxiter:
        max_nit = min(maxiter - total_nit, 25)
        sol = talon.solve(op, data,
                          reg_term=regterm,
                          cost_reltol=objective_relative_tolerance,
                          x_abstol=x_absolute_tolerance,
                          max_nit=max_nit,
                          x0=x0,
                          verbose='NONE')
        total_nit += sol.nit
        logging.info(f'Iteration: {total_nit} ~ '
                     f'1/2||Ax-y||2: {sol.fun[0]} ~ '
                     f'Omega: {sol.fun[1]} ~ '
                     f'Total: {np.sum(sol.fun)}')
        x0 = sol.x

        if sol.success:
            converged = True

    if not converged:
        logging.warning(f'Solver did not converge within {maxiter} '
                        f'iterations; saving the last estimate.')

    # save result (x0 always holds the latest estimate)
    logging.info(f'Saving result in {out_weights}')
    np.savetxt(out_weights, x0)

    # save dictionary
    if save_generators_indices_weights is not None:
        logging.info(
            'Saving linear operator in {} {} {}.'.format(
                *save_generators_indices_weights))
        np.save(save_generators_indices_weights[0], op.generators)
        sp.save_npz(save_generators_indices_weights[1],
                    op._indices_of_generators)
        sp.save_npz(save_generators_indices_weights[2], op._weights)
    elif save_operator_pickle:
        # The operator type was validated at the top of the function.
        logging.info(f'Saving linear operator in {save_operator_pickle}.')
        with open(save_operator_pickle, 'wb') as f:
            pickle.dump(op, f)

    if precomputed_indices_weights is None:
        # NOTE(review): only a message is logged here; no cleanup is
        # performed in the visible code.
        logging.info('Cleaning temporary files.')