#!/usr/bin/env python3
# Copyright (C) 2020 ASTRON (Netherlands Institute for Radio Astronomy)
# SPDX-License-Identifier: GPL-3.0-or-later
import numpy as np
import matplotlib.pyplot as plt
import casacore.tables
import signal
import argparse
import time
import idg
import idg.util
# Enable interactive plotting and create figure to plot into
plt.ion()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20,10))
# Set signal handler to exit when ctrl-c is pressed
def signal_handler(signal, frame):
exit()
signal.signal(signal.SIGINT, signal_handler)
######################################################################
# Command line argument parsing
######################################################################
parser = argparse.ArgumentParser(description='Run image domain gridding on a measurement set')
parser.add_argument(dest='msin', nargs=1, type=str,
help='path to measurement set')
parser.add_argument(dest='percentage',
nargs='?', type=int,
help='percentage of data to process',
default=100)
parser.add_argument('-c', '--column',
help='Data column used, such as DATA or CORRECTED_DATA (default: CORRECTED_DATA)',
required=False, default="CORRECTED_DATA")
parser.add_argument('--imagesize',
help='Image size (cell size / grid size)',
required=False, type=float, default=0.2)
modes = ["cpu", "hybrid", "gpu"]
parser.add_argument('--mode', choices=modes, required=False, help="Proxy to be used", default=modes[0])
args = parser.parse_args()
msin = args.msin[0]
percentage = args.percentage
image_size = args.imagesize
datacolumn = args.column
mode = args.mode
######################################################################
# Open measurementset
######################################################################
table = casacore.tables.taql(f"SELECT * FROM {msin} WHERE ANTENNA1 != ANTENNA2")
nr_times_ms = len(casacore.tables.taql(f"SELECT DISTINCT TIME FROM {msin}"))
print(f"nr_times_ms: {nr_times_ms}")
# Read parameters from measurementset
t_ant = casacore.tables.table(table.getkeyword("ANTENNA"))
t_spw = casacore.tables.table(table.getkeyword("SPECTRAL_WINDOW"))
frequencies = np.asarray(t_spw[0]['CHAN_FREQ'], dtype=np.float32)
nr_baselines = len(table.iter("TIME").next())
######################################################################
# Parameters
######################################################################
nr_stations = len(t_ant)
nr_channels = table[0][datacolumn].shape[0]
# Number of time steps per call to IDG.
# With 1 time step at a time, you can see the grid build up.
# This value is typically much larger, e.g. 128 or 256.
nr_timesteps = 1
# Number of A-terms in the time dimension per call to IDG.
# When no A-terms are needed, use nr_timeslots = 1,
# and set the # A-terms to identiy.
nr_timeslots = 1
nr_correlations = 4
grid_size = 512
subgrid_size = 32
kernel_size = 16
cell_size = image_size / grid_size
######################################################################
# Plot properties
######################################################################
colormap_grid = plt.get_cmap('hot')
colormap_img = plt.get_cmap('hot')
font_size = 16
######################################################################
# Initialize data
######################################################################
grid = idg.util.get_example_grid(nr_correlations, grid_size)
aterms = idg.util.get_identity_aterms(
nr_timeslots, nr_stations, subgrid_size, nr_correlations)
aterm_offsets = idg.util.get_example_aterm_offsets(
nr_timeslots, nr_timesteps)
# Initialize taper
taper = idg.util.get_example_taper(subgrid_size)
taper_grid = idg.util.get_identity_taper(grid_size)
######################################################################
# Initialize proxy
######################################################################
proxy = None
if mode == modes[0]:
proxy = idg.CPU.Optimized()
elif mode == modes[1]:
proxy = idg.HybridCUDA.GenericOptimized()
elif mode == modes[2]:
proxy = idg.CUDA.Generic()
w_step = 0.0
shift = np.zeros(2, np.float32)
proxy.set_grid(grid)
proxy.init_cache(subgrid_size, cell_size, w_step, shift)
######################################################################
# Process entire measurementset
######################################################################
nr_rows = table.nrows()
nr_rows_read = 0
nr_rows_per_batch = nr_baselines * nr_timesteps
nr_rows_to_process = nr_baselines * min(int( nr_rows * percentage / 100. ), nr_rows)
print(f"nr_rows: {nr_rows}")
# Initialize empty buffers
uvw = np.zeros(shape=(nr_baselines, nr_timesteps,3),
dtype=np.float32)
visibilities = np.zeros(shape=(nr_baselines, nr_timesteps, nr_channels,
nr_correlations),
dtype=np.complex64)
baselines = np.zeros(shape=(nr_baselines, 2),
dtype=np.intc)
img = np.zeros(shape=(nr_correlations, grid_size, grid_size),
dtype=np.complex64)
iteration = 0
print(f"nr_rows_read: {nr_rows_read}")
print(f"nr_rows_per_batch: {nr_rows_per_batch}")
print(f"nr_rows_to_process: {nr_rows_to_process}")
while (nr_rows_read + nr_rows_per_batch) < nr_rows_to_process:
# Reset buffers
uvw.fill(0)
visibilities.fill(0)
baselines.fill(0)
# Start timing
time_total = -time.time()
time_read = -time.time()
# Read nr_timesteps samples for all baselines including auto correlations
timestamp_block = table.getcol('TIME',
startrow = nr_rows_read,
nrow = nr_rows_per_batch)
antenna1_block = table.getcol('ANTENNA1',
startrow = nr_rows_read,
nrow = nr_rows_per_batch)
antenna2_block = table.getcol('ANTENNA2',
startrow = nr_rows_read,
nrow = nr_rows_per_batch)
uvw_block = table.getcol('UVW',
startrow = nr_rows_read,
nrow = nr_rows_per_batch)
vis_block = table.getcol(datacolumn,
startrow = nr_rows_read,
nrow = nr_rows_per_batch)
flags_block = table.getcol('FLAG',
startrow = nr_rows_read,
nrow = nr_rows_per_batch)
vis_block = vis_block * ~flags_block
vis_block[np.isnan(vis_block)] = 0
nr_rows_read += nr_rows_per_batch
time_read += time.time()
time_transpose = -time.time()
# Change precision
uvw_block = uvw_block.astype(np.float32)
vis_block = vis_block.astype(np.complex64)
# Remove autocorrelations
flags = antenna1_block != antenna2_block
antenna1_block = antenna1_block[flags]
antenna2_block = antenna2_block[flags]
uvw_block = uvw_block[flags]
vis_block = vis_block[flags]
# Reshape data
antenna1_block = np.reshape(antenna1_block,
newshape=(nr_timesteps, nr_baselines))
antenna2_block = np.reshape(antenna2_block,
newshape=(nr_timesteps, nr_baselines))
uvw_block = np.reshape(uvw_block,
newshape=(nr_timesteps, nr_baselines, 3))
vis_block = np.reshape(vis_block,
newshape=(nr_timesteps, nr_baselines,
nr_channels, nr_correlations))
# Transpose data
for t in range(nr_timesteps):
for bl in range(nr_baselines):
# Set baselines
antenna1 = antenna1_block[t][bl]
antenna2 = antenna2_block[t][bl]
baselines[bl] = (antenna1, antenna2)
# Set uvw
uvw[bl][t] = uvw_block[t][bl]
# Set visibilities
visibilities[bl][t] = vis_block[t][bl]
time_transpose += time.time()
# Grid visibilities
w_offset = 0.0
time_gridding = -time.time()
proxy.gridding(
kernel_size, frequencies, visibilities,
uvw, baselines, aterms, aterm_offsets, taper)
time_gridding += time.time()
# Get the raw grid
proxy.get_final_grid()
raw_grid = np.abs(grid[0,:,:])
# Compute fft over grid
time_fft = -time.time()
proxy.transform(idg.FourierDomainToImageDomain)
img_real = np.real(grid[0,:,:])
time_fft += time.time()
time_plot = -time.time()
# Remove taper from grid
img_real = img_real/taper_grid
# Crop image
img_crop = img_real[int(grid_size*0.1):int(grid_size*0.9),int(grid_size*0.1):int(grid_size*0.9)]
# Make first plot (raw grid)
ax1.set_xticks([])
ax1.set_yticks([])
ax1.imshow(np.log(raw_grid + 1), cmap=colormap_grid)
time1 = timestamp_block[0]
ax1.set_title("UV Data: %2.2i:%2.2i\n" % (np.mod(int(time1/3600 ),24), np.mod(int(time1/60),60)), fontsize=font_size)
# Make second plot (processed grid)
m = np.amax(img_crop)
ax2.imshow(img_crop, interpolation='nearest', clim = (-0.01*m, 0.3*m), cmap=colormap_img)
ax2.set_title("Sky image\n", fontsize=font_size)
ax2.set_xticks([])
ax2.set_yticks([])
# Draw figure
plt.pause(0.01)
time_plot += time.time()
# Print timings
time_total += time.time()
print(">>> Iteration %d" % iteration)
print("Runtime total: %5d ms" % (time_total*1000))
print("Runtime reading: %5d ms (%5.2f %%)" % (time_read*1000, 100.0 * time_read/time_total))
print("Runtime transpose: %5d ms (%5.2f %%)" % (time_transpose*1000, 100.0 * time_transpose/time_total))
print("Runtime gridding: %5d ms (%5.2f %%)" % (time_gridding*1000, 100.0 * time_gridding/time_total))
print("Runtime fft: %5d ms (%5.2f %%)" % (time_fft*1000, 100.0 * time_fft/time_total))
print("Runtime plot: %5d ms (%5.2f %%)" % (time_plot*1000, 100.0 * time_plot/time_total))
print()
iteration += 1
plt.show()
# Do not close window at the end?
plt.show(block=True)