Source code for bdschism.combine_nc

import xarray as xr
import os
import click
import glob
import re
import numpy as np


def validate_time_step_consistency(time_values, atol=1e-9, max_report=20):
    """Verify that consecutive time values are separated by a constant step.

    The first step (``time_values[1] - time_values[0]``) is taken as the
    reference; every other step must equal it within the absolute tolerance
    ``atol`` (no relative tolerance is applied).

    Parameters
    ----------
    time_values : array-like
        Time coordinate values; converted to a float array.
    atol : float, optional
        Absolute tolerance used when comparing each step to the reference.
    max_report : int, optional
        Maximum number of offending intervals listed in the error message.

    Raises
    ------
    ValueError
        If at least one step differs from the reference step. The message
        lists the offending index pairs, their times and the observed step.
    """
    t = np.asarray(time_values, dtype=float)
    if t.size < 2:
        # A single point (or none) has no step to compare.
        print(
            "\tTime coordinate has fewer than 2 points; timestep consistency check skipped."
        )
        return

    steps = np.diff(t)
    reference_step = steps[0]
    step_ok = np.isclose(steps, reference_step, atol=atol, rtol=0.0)
    offenders = np.where(~step_ok)[0]
    n_bad = offenders.size

    if n_bad == 0:
        print(f"\tTimestep is consistent: dt={reference_step}.")
        return

    header = (
        f"\tFound {n_bad} inconsistent timestep(s). Expected dt={reference_step}."
    )
    details = [
        f" idx {k}->{k+1}: t0={t[k]}, t1={t[k+1]}, dt={steps[k]}"
        for k in offenders[:max_report]
    ]
    # Only mention truncation when some mismatches were actually left out.
    overflow = (
        [f" ... and {n_bad - max_report} more mismatch(es)."]
        if n_bad > max_report
        else []
    )
    raise ValueError("\n".join([header, *details, *overflow]))
def get_selected_files(template, start_num, end_num):
    """Return files matching ``template`` whose number lies in a closed range.

    The ``*`` in ``template`` is used as a glob wildcard to list candidates
    and is then interpreted as "one or more digits" to extract each file's
    number. Files whose wildcard part is not purely numeric are ignored.

    Parameters
    ----------
    template : str
        Path pattern containing ``*``, e.g. ``uv3d_*.th.nc``.
    start_num, end_num : int
        Inclusive bounds on the extracted number.

    Returns
    -------
    list of str
        Matching paths ordered by ascending extracted number.

    Raises
    ------
    ValueError
        If ``template`` has no ``*`` or no file falls within the range.
    """
    if "*" not in template:
        raise ValueError("template must include '*' wildcard, e.g., uv3d_*.th.nc")

    # Escape the template literally, then turn the wildcard into a digit group.
    digits_pattern = re.escape(template).replace(r"\*", r"(\d+)")
    number_regex = re.compile("^" + digits_pattern + "$")

    numbered = []
    for candidate in glob.glob(template):
        hit = number_regex.match(candidate)
        if hit is None:
            continue
        index = int(hit.group(1))
        if index < start_num or index > end_num:
            continue
        numbered.append((index, candidate))

    # Sort on the number only so ordering is numeric, not lexicographic.
    ordered = sorted(numbered, key=lambda pair: pair[0])
    input_files = [candidate for _, candidate in ordered]

    if not input_files:
        raise ValueError(
            f"No input files found for template '{template}' in range "
            f"[{start_num}, {end_num}]"
        )
    return input_files
def combine_nc(input_files, outfile):
    """Combine NetCDF files along the ``time`` dimension into one file.

    Each input (e.g., out2d_1.nc, out2d_2.nc, ...) is opened with
    ``decode_times=False``. The first time slice of every dataset except the
    first is dropped to avoid duplicated boundary times, the pieces are
    concatenated along ``time``, the resulting time coordinate is checked for
    a constant step, and the result is written to ``outfile``.

    Parameters
    ----------
    input_files : sequence of str
        Paths of the NetCDF files to combine, in the order they should be
        concatenated.
    outfile : str
        Destination path. Its parent directory is created if it does not
        exist.

    Raises
    ------
    ValueError
        If ``input_files`` is empty, or propagated from
        ``validate_time_step_consistency`` when the combined time coordinate
        has an inconsistent step. Nothing is written in either case.
    """
    input_files = list(input_files)
    if not input_files:
        raise ValueError("No input files provided to combine_nc.")

    datasets = []
    combined_ds = None
    try:
        # Open inside the try block so that a failure part-way through the
        # list still closes the datasets that were already opened.
        for path in input_files:
            datasets.append(xr.open_dataset(path, decode_times=False))

        # Drop the first time slice for every dataset except the first one
        # to avoid duplicate boundaries.
        pieces = [
            dataset if idx == 0 else dataset.isel(time=slice(1, None))
            for idx, dataset in enumerate(datasets)
        ]
        combined_ds = xr.concat(pieces, dim="time")

        # Validate (not modify) the combined time coordinate before writing.
        validate_time_step_consistency(combined_ds["time"].values)

        # Save the combined dataset to a new NetCDF file.
        out_dir = os.path.dirname(outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        combined_ds.to_netcdf(outfile)
    finally:
        # Always release file handles, including when validation or writing
        # raised.
        for dataset in datasets:
            dataset.close()
        if combined_ds is not None:
            combined_ds.close()
def combine_uv3d(
    start_num,
    end_num,
    outfile,
    tmp_out_dir="./outputs.tropic/uv3d",
):
    """Combine numbered ``uv3d_*.th.nc`` files from a directory into one file.

    Parameters
    ----------
    start_num, end_num : int
        Inclusive range of file numbers to combine.
    outfile : str
        Path of the combined NetCDF file to write.
    tmp_out_dir : str, optional
        Directory searched for ``uv3d_<number>.th.nc`` files.
    """
    uv3d_files = get_selected_files(
        f"{tmp_out_dir}/uv3d_*.th.nc", start_num, end_num
    )
    combine_nc(uv3d_files, outfile)
    print(f"\nOutfile written to : {outfile}")
def compare_dataarray_all_axes(validate_da, check_da, atol=1e-6, rtol=1e-6):
    """Assert that two DataArrays agree in dims, shape, coordinates and data.

    Checks are performed in order: dimension names/order, shape, each
    dimension coordinate present on both arrays, then the data values
    (NaNs in the same position are treated as equal for the data only).

    Parameters
    ----------
    validate_da, check_da
        Arrays to compare; they must expose ``dims``, ``shape``, ``coords``
        and ``values``.
    atol, rtol : float, optional
        Absolute and relative tolerances used for coordinate and data
        comparisons.

    Raises
    ------
    ValueError
        On the first mismatch found; the message names the failing check and,
        for coordinate/data mismatches, the first offending index and values.
    """
    if validate_da.dims != check_da.dims:
        raise ValueError(
            f"Dimension name/order mismatch: {validate_da.dims} != {check_da.dims}"
        )
    if validate_da.shape != check_da.shape:
        raise ValueError(f"Shape mismatch: {validate_da.shape} != {check_da.shape}")

    tolerances = {"atol": atol, "rtol": rtol}

    for axis_name in validate_da.dims:
        # Only dimensions that carry a coordinate on both sides are compared.
        if axis_name not in validate_da.coords or axis_name not in check_da.coords:
            continue

        ref_axis = validate_da.coords[axis_name].values
        other_axis = check_da.coords[axis_name].values
        if ref_axis.shape != other_axis.shape:
            raise ValueError(
                f"Coordinate shape mismatch on '{axis_name}': {ref_axis.shape} != {other_axis.shape}"
            )

        axis_ok = np.isclose(ref_axis, other_axis, **tolerances)
        if not np.all(axis_ok):
            first_bad = int(np.where(~axis_ok)[0][0])
            raise ValueError(
                f"Coordinate mismatch on '{axis_name}' at index {first_bad}: "
                f"validate={ref_axis[first_bad]}, check={other_axis[first_bad]}"
            )

    ref_vals = validate_da.values
    other_vals = check_da.values
    data_ok = np.isclose(ref_vals, other_vals, equal_nan=True, **tolerances)
    if not np.all(data_ok):
        # Report the first mismatching element as a full N-d index.
        location = tuple(int(i) for i in np.argwhere(~data_ok)[0])
        raise ValueError(
            f"Data mismatch at index {location}: "
            f"validate={ref_vals[location]}, check={other_vals[location]}"
        )

    print("time_series matches across all axes and values.")
@click.command(
    help=(
        "Combine NetCDF files along the time dimension.\n\n"
        "\b\n"
        "Arguments:\n"
        " TEMPLATE Wildcard template for input files (must include '*').\n"
        " START First numeric index to include (inclusive).\n"
        " END Last numeric index to include (inclusive).\n\n"
        "\b\n"
        "Examples:\n"
        " bds combine_nc './outputs.tropic/uv3d/uv3d_*.th.nc' 1 10 -o ./uv3d_combined.th.nc\n"
        " bds combine_nc './outputs/out2d_*.nc' 1 5 -o ./out2d_combined.nc"
    )
)
@click.argument("template", type=str)
@click.argument(
    "start",
    type=int,
)
@click.argument(
    "end",
    type=int,
)
@click.option(
    "-o",
    "--output",
    type=click.Path(),
    default="./uv3D.th.nc",
    help="Output file path for the combined NetCDF (default: ./uv3D.th.nc).",
)
@click.option(
    "--tmp-out-dir",
    type=click.Path(),
    default=None,
    help=(
        "Directory containing uv3d intermediate files. "
        "Used only for uv3d mode; defaults to ./outputs.tropic/uv3d."
    ),
)
@click.help_option("-h", "--help")
def combine_nc_cli(template, start, end, output, tmp_out_dir):
    """Command line utility for combining NetCDF files.

    The output path is given with ``-o/--output``; only TEMPLATE, START and
    END are positional.

    Example usage:
        bds combine_nc 'uv3d' 1 10 -o ./uv3d_combined.th.nc
    """
    if "uv3d" in template.lower():
        # NOTE: in uv3d mode only the presence of "uv3d" in TEMPLATE matters;
        # the files are looked up as uv3d_*.th.nc under --tmp-out-dir (or
        # combine_uv3d's default directory), not under TEMPLATE's directory.
        print("Combining uv3d files using combine_uv3d...")
        if tmp_out_dir is None:
            combine_uv3d(start, end, output)
        else:
            combine_uv3d(start, end, output, tmp_out_dir=tmp_out_dir)
    else:
        print(
            f"Combining generic NetCDF files using combine_nc and template: {template}..."
        )
        input_files = get_selected_files(template, start, end)
        combine_nc(input_files, output)


if __name__ == "__main__":
    combine_nc_cli()

    # Manual validation snippet, kept for reference:
    # os.chdir("/scratch/projects/dsp/updated_schism_202602/simulations/baseline_lhc_3")
    # start_num = 1
    # end_num = 3
    # outfile = "./uv3d_combined.th.nc"
    # combine_uv3d(start_num, end_num, outfile)
    # check_file = "./outputs.tropic/uv3D.th.nc"
    # check_ds = xr.open_dataset(check_file, decode_times=False)
    # validate_ds = xr.open_dataset(outfile, decode_times=False)
    # compare_dataarray_all_axes(validate_ds["time_series"], check_ds["time_series"])
    # check_ds.close()
    # validate_ds.close()