Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 137 additions & 124 deletions STalign/STalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None,
a=500.0,p=2.0,expand=2.0,nt=3,
niter=5000,diffeo_start=0, epL=2e-8, epT=2e-1, epV=2e3,
sigmaM=1.0,sigmaB=2.0,sigmaA=5.0,sigmaR=5e5,sigmaP=2e1,
device='cpu',dtype=torch.float64, muB=None, muA=None):
device='cpu',dtype=torch.float64, muB=None, muA=None, display=True):
''' Run LDDMM between a pair of images.

This jointly estimates an affine transform A, and a diffeomorphism phi.
Expand Down Expand Up @@ -995,6 +995,9 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None,
If the target is a grayscale image, this should be a tensor of size 1.
muB: torch tensor whose dimension is the same as the target image
Defaults to None, which means we estimate this. If you provide a value, we will not estimate it.
display: binary
Defaults to True
Decides if the plots of the function will be shown

Returns a dictionary
-------
Expand All @@ -1011,6 +1014,8 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None,
Resulting weight 2D array (background)
'WA': torch tensor
Resulting weight 2D array (artifact)
'Errors': list
List of the progresion of the errors in algingment
}

'''
Expand Down Expand Up @@ -1089,10 +1094,11 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None,
#ax.imshow(K[0].cpu())
DV = torch.prod(dv)
Ki = torch.fft.ifftn(K).real
fig,ax = plt.subplots()
ax.imshow(Ki.clone().detach().cpu().numpy(),vmin=0.0,extent=extentV)
ax.set_title('smoothing kernel')
fig.canvas.draw()
if display:
fig,ax = plt.subplots()
ax.imshow(Ki.clone().detach().cpu().numpy(),vmin=0.0,extent=extentV)
ax.set_title('smoothing kernel')
fig.canvas.draw()


# nt = 3
Expand Down Expand Up @@ -1137,9 +1143,10 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None,
else:
estimate_muB = False

fig,ax = plt.subplots(2,3)
ax = ax.ravel()
figE,axE = plt.subplots(1,3)
if display:
fig,ax = plt.subplots(2,3)
ax = ax.ravel()
figE,axE = plt.subplots(1,3)
Esave = []

try:
Expand Down Expand Up @@ -1253,65 +1260,67 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None,

# draw
if not it%10:
ax[0].cla()
ax[0].imshow( ((AI-torch.amin(AI,(1,2))[...,None,None])/(torch.amax(AI,(1,2))-torch.amin(AI,(1,2)))[...,None,None]).permute(1,2,0).clone().detach().cpu(),extent=extentJ)
ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[0].set_title('space tformed source')

ax[1].cla()
ax[1].imshow(clip(fAI.permute(1,2,0).clone().detach()/torch.max(J).item()).cpu(),extent=extentJ)
ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[1].set_title('contrast tformed source')

ax[5].cla()
ax[5].imshow(clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,0).clone().detach().cpu()*0.5+0.5,extent=extentJ)
ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[5].set_title('Error')

ax[2].cla()
ax[2].imshow(J.permute(1,2,0).cpu()/torch.max(J).item(),extent=extentJ)
ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[2].set_title('Target')

ax[4].cla()
ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu(),extent=extentJ)
ax[4].set_title('Weights')


toshow = v[0].clone().detach().cpu()
toshow /= torch.max(torch.abs(toshow))
toshow = toshow*0.5+0.5
toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1)
ax[3].cla()
ax[3].imshow(clip(toshow),extent=extentV)
ax[3].set_title('velocity')

axE[0].cla()
axE[0].plot(Esave)
axE[0].legend(['E','EM','ER','EP'])
axE[0].set_yscale('log')
axE[1].cla()
axE[1].plot([e[:2] for e in Esave])
axE[1].legend(['E','EM'])
axE[1].set_yscale('log')
axE[2].cla()
axE[2].plot([e[2] for e in Esave])
axE[2].legend(['ER'])
axE[2].set_yscale('log')



fig.canvas.draw()
figE.canvas.draw()
if display:
ax[0].cla()
ax[0].imshow( ((AI-torch.amin(AI,(1,2))[...,None,None])/(torch.amax(AI,(1,2))-torch.amin(AI,(1,2)))[...,None,None]).permute(1,2,0).clone().detach().cpu(),extent=extentJ)
ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[0].set_title('space tformed source')

ax[1].cla()
ax[1].imshow(clip(fAI.permute(1,2,0).clone().detach()/torch.max(J).item()).cpu(),extent=extentJ)
ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[1].set_title('contrast tformed source')

ax[5].cla()
ax[5].imshow(clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,0).clone().detach().cpu()*0.5+0.5,extent=extentJ)
ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[5].set_title('Error')

ax[2].cla()
ax[2].imshow(J.permute(1,2,0).cpu()/torch.max(J).item(),extent=extentJ)
ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[2].set_title('Target')

ax[4].cla()
ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu(),extent=extentJ)
ax[4].set_title('Weights')


toshow = v[0].clone().detach().cpu()
toshow /= torch.max(torch.abs(toshow))
toshow = toshow*0.5+0.5
toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1)
ax[3].cla()
ax[3].imshow(clip(toshow),extent=extentV)
ax[3].set_title('velocity')

axE[0].cla()
axE[0].plot(Esave)
axE[0].legend(['E','EM','ER','EP'])
axE[0].set_yscale('log')
axE[1].cla()
axE[1].plot([e[:2] for e in Esave])
axE[1].legend(['E','EM'])
axE[1].set_yscale('log')
axE[2].cla()
axE[2].plot([e[2] for e in Esave])
axE[2].legend(['ER'])
axE[2].set_yscale('log')



fig.canvas.draw()
figE.canvas.draw()

return {
'A': A.clone().detach(),
'v': v.clone().detach(),
'xv': xv,
'WM': WM.clone().detach(),
'WB': WB.clone().detach(),
'WA': WA.clone().detach()
'WA': WA.clone().detach(),
'Errors': Esave,
}


Expand All @@ -1320,7 +1329,7 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None,
a=500.0,p=2.0,expand=1.25,nt=3,
niter=5000,diffeo_start=0, epL=1e-6, epT=1e1, epV=1e3,
sigmaM=1.0,sigmaB=2.0,sigmaA=5.0,sigmaR=1e8,sigmaP=2e1,
device='cpu',dtype=torch.float64, muA=None, muB = None):
device='cpu',dtype=torch.float64, muA=None, muB = None, display=True):
''' LDDMM for 3D to 2D slice mapping.

muA: torch tensor whose dimension is the same as the target image
Expand Down Expand Up @@ -1395,10 +1404,11 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None,
#ax.imshow(K[0].cpu())
DV = torch.prod(dv)
Ki = torch.fft.ifftn(K).real
fig,ax = plt.subplots()
ax.imshow(Ki[Ki.shape[0]//2].clone().detach().cpu().numpy(),vmin=0.0,extent=extentV)
ax.set_title('smoothing kernel')
fig.canvas.draw()
if display:
fig,ax = plt.subplots()
ax.imshow(Ki[Ki.shape[0]//2].clone().detach().cpu().numpy(),vmin=0.0,extent=extentV)
ax.set_title('smoothing kernel')
fig.canvas.draw()

# steps
epL = torch.tensor(epL,device=device,dtype=dtype)
Expand Down Expand Up @@ -1442,10 +1452,11 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None,
'''

# a figure
fig,ax = plt.subplots(2,3)
ax = ax.ravel()
figE,axE = plt.subplots(1,3)
axE = axE.ravel()
if display:
fig,ax = plt.subplots(2,3)
ax = ax.ravel()
figE,axE = plt.subplots(1,3)
axE = axE.ravel()
Esave = []
# zero gradients
try:
Expand Down Expand Up @@ -1564,60 +1575,61 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None,

# draw
if not it%10:
ax[0].cla()
Ishow = ((AI-torch.amin(AI,(1,2,3))[...,None,None])/(torch.amax(AI,(1,2,3))-torch.amin(AI,(1,2,3)))[...,None,None,None]).permute(1,2,3,0).clone().detach().cpu()
ax[0].imshow( Ishow[0,...,0] ,extent=extentJ)
#ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[0].set_title('space tformed source')

ax[1].cla()
Ishow = clip(fAI.permute(1,2,3,0).clone().detach()/torch.max(J).item()).cpu()
ax[1].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1)
#ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[1].set_title('contrast tformed source')

ax[5].cla()
Ishow = clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,3,0).clone().detach().cpu()*0.5+0.5
ax[5].imshow(Ishow[0,...,0],extent=extentJ)
#ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
#ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[5].set_title('Error')

ax[2].cla()
Ishow = J.permute(1,2,3,0).cpu()/torch.max(J).item()
ax[2].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1)
#ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[2].set_title('Target')

ax[4].cla()
ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu()[0],extent=extentJ)
ax[4].set_title('Weights')


toshow = v[0].clone().detach().cpu() # initial velocity, components are rgb
toshow /= torch.max(torch.abs(toshow))
toshow = toshow*0.5+0.5
#toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1)
ax[3].cla()
ax[3].imshow(clip(toshow)[toshow.shape[0]//2],extent=extentV)
ax[3].set_title('velocity')

axE[0].cla()
axE[0].plot(Esave)
axE[0].legend(['E','EM','ER','EP'])
axE[0].set_yscale('log')
axE[1].cla()
axE[1].plot([e[:2] for e in Esave])
axE[1].legend(['E','EM'])
axE[1].set_yscale('log')
axE[2].cla()
axE[2].plot([e[2] for e in Esave])
axE[2].legend(['ER'])
axE[2].set_yscale('log')


fig.canvas.draw()
figE.canvas.draw()
if display:
ax[0].cla()
Ishow = ((AI-torch.amin(AI,(1,2,3))[...,None,None])/(torch.amax(AI,(1,2,3))-torch.amin(AI,(1,2,3)))[...,None,None,None]).permute(1,2,3,0).clone().detach().cpu()
ax[0].imshow( Ishow[0,...,0] ,extent=extentJ)
#ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[0].set_title('space tformed source')

ax[1].cla()
Ishow = clip(fAI.permute(1,2,3,0).clone().detach()/torch.max(J).item()).cpu()
ax[1].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1)
#ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
ax[1].set_title('contrast tformed source')

ax[5].cla()
Ishow = clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,3,0).clone().detach().cpu()*0.5+0.5
ax[5].imshow(Ishow[0,...,0],extent=extentJ)
#ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu())
#ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[5].set_title('Error')

ax[2].cla()
Ishow = J.permute(1,2,3,0).cpu()/torch.max(J).item()
ax[2].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1)
#ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu())
ax[2].set_title('Target')

ax[4].cla()
ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu()[0],extent=extentJ)
ax[4].set_title('Weights')


toshow = v[0].clone().detach().cpu() # initial velocity, components are rgb
toshow /= torch.max(torch.abs(toshow))
toshow = toshow*0.5+0.5
#toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1)
ax[3].cla()
ax[3].imshow(clip(toshow)[toshow.shape[0]//2],extent=extentV)
ax[3].set_title('velocity')

axE[0].cla()
axE[0].plot(Esave)
axE[0].legend(['E','EM','ER','EP'])
axE[0].set_yscale('log')
axE[1].cla()
axE[1].plot([e[:2] for e in Esave])
axE[1].legend(['E','EM'])
axE[1].set_yscale('log')
axE[2].cla()
axE[2].plot([e[2] for e in Esave])
axE[2].legend(['ER'])
axE[2].set_yscale('log')


fig.canvas.draw()
figE.canvas.draw()

return {
'A': A.clone().detach(),
Expand All @@ -1626,7 +1638,8 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None,
'WM': WM.clone().detach(),
'WB': WB.clone().detach(),
'WA': WA.clone().detach(),
'Xs': Xs.clone().detach()
'Xs': Xs.clone().detach(),
'Errors': Esave,
}


Expand Down