
How To Recover 3d Image From Its Patches In Python?

I have a 3D image with shape D x H x W. I was able to extract overlapping patches of shape pd x ph x pw from it, and I do some processing on each patch. Now, I would like to gen

Solution 1:

This will do the reverse. However, since your patches overlap, the result is only well-defined if the patches agree on the values in the overlapping regions.

import numpy as np

def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    # Blank output volume; the patches are written into it through a strided view.
    out = np.zeros(out_shape, patches.dtype)
    patch_shape = patches.shape[-3:]
    steps = (xstep, ystep, zstep)
    # Number of patch positions per axis: (size - patch) // step + 1
    n_patches = tuple((s - p) // st + 1 for s, p, st in zip(out.shape, patch_shape, steps))
    # 6D view: first three axes select the patch, last three index voxels inside it.
    patches_6D = np.lib.stride_tricks.as_strided(
        out, n_patches + patch_shape,
        tuple(st * s for st, s in zip(out.strides, steps)) + out.strides)
    # Overlapping voxels are overwritten, so they only make sense if the patches agree there.
    patches_6D[...] = patches.reshape(patches_6D.shape)
    return out
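The number of patch positions per axis has to match however the patches were produced. As a minimal sketch, assuming the patches sit on a regular grid with the same steps, they can be extracted with NumPy's `sliding_window_view` (NumPy 1.20 or later) followed by step slicing. `extract_patches_3D` is a hypothetical helper name, not part of any library.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def extract_patches_3D(volume, patch_shape, steps):
    """Hypothetical helper: overlapping patches on a regular grid.

    Returns an array of shape (nx, ny, nz, pd, ph, pw), where
    n = (size - patch) // step + 1 along each axis.
    """
    windows = sliding_window_view(volume, patch_shape)  # every possible position
    xs, ys, zs = steps
    # Keep only the grid positions; copy so the patches no longer alias `volume`.
    return windows[::xs, ::ys, ::zs].copy()

vol = np.arange(48 * 48 * 48, dtype=np.float32).reshape(48, 48, 48)
patches = extract_patches_3D(vol, (24, 24, 24), (12, 12, 12))
print(patches.shape)  # (3, 3, 3, 24, 24, 24)
```

Patch `(i, j, k)` starts at voxel `(i * xstep, j * ystep, k * zstep)`, which is the layout the strided view in the function above writes into.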

Update: here is a safer version that averages overlapping pixels:

import numpy as np

def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    # Accumulate in float64 so small integer dtypes (e.g. uint8) cannot overflow.
    out = np.zeros(out_shape, np.float64)    # sum of patch values per voxel
    denom = np.zeros(out_shape, np.float64)  # number of patches covering each voxel
    patch_shape = patches.shape[-3:]
    steps = (xstep, ystep, zstep)
    # Number of patch positions per axis: (size - patch) // step + 1
    n_patches = tuple((s - p) // st + 1 for s, p, st in zip(out.shape, patch_shape, steps))
    view_shape = n_patches + patch_shape

    def strided_view(a):
        # 6D view: first three axes select the patch, last three index voxels inside it.
        return np.lib.stride_tricks.as_strided(
            a, view_shape, tuple(st * s for st, s in zip(a.strides, steps)) + a.strides)

    patches_6D = strided_view(out)
    denom_6D = strided_view(denom)
    # np.add.at is unbuffered, so contributions that alias the same voxel all accumulate.
    idx = tuple(x.ravel() for x in np.indices(view_shape))
    np.add.at(patches_6D, idx, patches.ravel())
    np.add.at(denom_6D, idx, 1)
    # Voxels covered by no patch have denom == 0 and come out as NaN (with a warning).
    return out / denom
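If `as_strided` feels fragile (a wrong stride silently reads or writes outside the array), the same averaging can be written with an explicit loop over patch positions. This is a sketch assuming patches shaped `(nx, ny, nz, pd, ph, pw)` on a regular grid; `merge_patches_3D` is a hypothetical name. It is slower when there are many small patches, but each `+=` touches a single patch slice, so nothing aliases and the result is easy to verify.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def merge_patches_3D(out_shape, patches, steps):
    """Hypothetical helper: average overlapping patches back into a volume.

    Assumes patches has shape (nx, ny, nz, pd, ph, pw) and that patch (i, j, k)
    starts at voxel (i * xstep, j * ystep, k * zstep).
    """
    nx, ny, nz, pd, ph, pw = patches.shape
    xs, ys, zs = steps
    total = np.zeros(out_shape, dtype=np.float64)  # sum of contributions per voxel
    count = np.zeros(out_shape, dtype=np.float64)  # number of contributing patches
    for i in range(nx):
        for j in range(ny):
            for k in range(nz):
                region = (slice(i * xs, i * xs + pd),
                          slice(j * ys, j * ys + ph),
                          slice(k * zs, k * zs + pw))
                total[region] += patches[i, j, k]  # one patch at a time: no aliasing
                count[region] += 1
    # Average where covered; voxels no patch reaches are left as NaN.
    result = np.full(out_shape, np.nan)
    np.divide(total, count, out=result, where=count > 0)
    return result

# Round trip: unmodified patches should give back the covered part of the volume.
vol = np.random.default_rng(0).random((50, 48, 48))
patches = sliding_window_view(vol, (24, 24, 24))[::12, ::12, ::12]
rebuilt = merge_patches_3D(vol.shape, patches, (12, 12, 12))
```

With a depth of 50, the last two slices along the first axis are not reached by any patch, so they come back as NaN; everything else matches `vol` up to floating-point rounding.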
