diff --git a/containerd/containerd.go b/containerd/containerd.go index 0728cad..4ab2d8c 100644 --- a/containerd/containerd.go +++ b/containerd/containerd.go @@ -1,6 +1,8 @@ package containerd import ( + "fmt" + "github.com/containerd/containerd" "github.com/containerd/containerd/cio" "github.com/containerd/containerd/oci" @@ -18,13 +20,40 @@ func (d *Driver) pullImage(imageName string) (containerd.Image, error) { return d.client.Pull(d.ctxContainerd, imageName, containerd.WithPullUnpack) } -func (d *Driver) createContainer(image containerd.Image, containerName, containerSnapshotName, containerdRuntime string, env []string) (containerd.Container, error) { +func (d *Driver) createContainer(image containerd.Image, containerName, containerSnapshotName, containerdRuntime string, env []string, config *TaskConfig) (containerd.Container, error) { + if config.Command == "" && len(config.Args) > 0 { + return nil, fmt.Errorf("Command is empty. Cannot set --args without --command.") + } + + var args []string + if config.Command != "" { + args = append(args, config.Command) + } + + if len(config.Args) > 0 { + args = append(args, config.Args...) + } + + var opts []oci.SpecOpts + + opts = append(opts, oci.WithImageConfigArgs(image, args)) + + if len(config.CapAdd) > 0 { + opts = append(opts, oci.WithCapabilities(config.CapAdd)) + } + + if len(config.CapDrop) > 0 { + opts = append(opts, oci.WithDroppedCapabilities(config.CapDrop)) + } + + opts = append(opts, oci.WithEnv(env)) + return d.client.NewContainer( d.ctxContainerd, containerName, containerd.WithRuntime(containerdRuntime, nil), containerd.WithNewSnapshot(containerSnapshotName, image), - containerd.WithNewSpec(oci.WithImageConfig(image), oci.WithEnv(env)), + containerd.WithNewSpec(opts...), ) } diff --git a/containerd/driver.go b/containerd/driver.go index a946f80..e7094fe 100644 --- a/containerd/driver.go +++ b/containerd/driver.go @@ -69,7 +69,11 @@ var ( // this is used to validate the configuration specified for the plugin // when a job is submitted. taskConfigSpec = hclspec.NewObject(map[string]*hclspec.Spec{ - "image": hclspec.NewAttr("image", "string", true), + "image": hclspec.NewAttr("image", "string", true), + "command": hclspec.NewAttr("command", "string", false), + "args": hclspec.NewAttr("args", "list(string)", false), + "cap_add": hclspec.NewAttr("cap_add", "list(string)", false), + "cap_drop": hclspec.NewAttr("cap_drop", "list(string)", false), }) // capabilities indicates what optional features this driver supports @@ -92,7 +96,11 @@ type Config struct { // TaskConfig contains configuration information for a task that runs with // this plugin type TaskConfig struct { - Image string `codec:"image"` + Image string `codec:"image"` + Command string `codec:"command"` + Args []string `codec:"args"` + CapAdd []string `codec:"cap_add"` + CapDrop []string `codec:"cap_drop"` } // TaskState is the runtime state which is encoded in the handle returned to @@ -302,7 +310,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive } containerSnapshotName := fmt.Sprintf("%s-snapshot", containerName) - container, err := d.createContainer(image, containerName, containerSnapshotName, d.config.ContainerdRuntime, env) + container, err := d.createContainer(image, containerName, containerSnapshotName, d.config.ContainerdRuntime, env, &driverConfig) if err != nil { return nil, nil, fmt.Errorf("Error in creating container: %v", err) }